author     gkoehler <gkoehler@cvs.openbsd.org>   2020-05-02 05:34:09 +0000
committer  gkoehler <gkoehler@cvs.openbsd.org>   2020-05-02 05:34:09 +0000
commit     d59f6f96219aaf0db8324a26b289ca35466ecbb2 (patch)
tree       0e77c5b5cf783edeababae1d5e0275cd94c3f54d /gnu
parent     d6eb54220157a9afd74dab08a6e7895f2a7928ad (diff)
Don't make an illegal adde. Avoids fatal error on PowerPC.
When the DAG truncates an ISD::ADDE node, DAGCombiner may optimize it by
making an adde with smaller operands. PowerPC has i1 registers, and may
truncate an i32 adde to i1, but an i1 adde is not legal for PowerPC, and
the legalize-ops phase can't fix it. This was causing
"fatal error: error in backend: Cannot select...".

cwen@ reported the error.
ok mortimer@ kettenis@ deraadt@
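For context, the narrowing in question has the shape of a truncate-of-adde
fold inside DAGCombiner::visitTRUNCATE. Below is a minimal sketch of such a
fold with the kind of legality guard this commit describes; it is
illustrative only, not the verbatim patch hunk, and the names used (N, N0,
VT, DAG, TLI, LegalOperations) are the DAGCombiner members and locals
visible elsewhere in this diff:

    // Illustrative sketch, not the exact patch: narrow
    //   (trunc (adde X, Y, Carry)) -> (adde (trunc X), (trunc Y), Carry)
    // only when the target can actually select an ADDE of the narrow type.
    SDValue N0 = N->getOperand(0);  // the adde feeding the truncate
    EVT VT = N->getValueType(0);    // the narrow result type (i1 here)
    if (N0.getOpcode() == ISD::ADDE && N0.hasOneUse() &&
        !N0->hasAnyUseOfValue(1) && // the carry-out must be unused
        (!LegalOperations || TLI.isOperationLegal(ISD::ADDE, VT))) {
      SDLoc DL(N);
      SDValue X = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
      SDValue Y = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(1));
      SDVTList VTs = DAG.getVTList(VT, N0->getValueType(1));
      return DAG.getNode(ISD::ADDE, DL, VTs, X, Y, N0.getOperand(2));
    }

Without the TLI.isOperationLegal check, the combiner can emit an i1 ADDE
after operation legalization has already run; nothing later re-legalizes
it, and instruction selection dies with the "Cannot select" error quoted
above.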
Diffstat (limited to 'gnu')
-rw-r--r--  gnu/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp  13857
1 file changed, 9223 insertions(+), 4634 deletions(-)
diff --git a/gnu/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/gnu/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 210aa95b02d..8203476ed4e 100644
--- a/gnu/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/gnu/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -1,4 +1,4 @@
-//===-- DAGCombiner.cpp - Implement a DAG node combiner -------------------===//
+//===- DAGCombiner.cpp - Implement a DAG node combiner --------------------===//
//
// The LLVM Compiler Infrastructure
//
@@ -16,28 +16,64 @@
//
//===----------------------------------------------------------------------===//
-#include "llvm/CodeGen/SelectionDAG.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/IntervalMap.h"
+#include "llvm/ADT/None.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/MemoryLocation.h"
+#include "llvm/CodeGen/DAGCombine.h"
+#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineMemOperand.h"
+#include "llvm/CodeGen/RuntimeLibcalls.h"
+#include "llvm/CodeGen/SelectionDAG.h"
+#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
+#include "llvm/CodeGen/SelectionDAGNodes.h"
+#include "llvm/CodeGen/SelectionDAGTargetInfo.h"
+#include "llvm/CodeGen/TargetLowering.h"
+#include "llvm/CodeGen/TargetRegisterInfo.h"
+#include "llvm/CodeGen/TargetSubtargetInfo.h"
+#include "llvm/CodeGen/ValueTypes.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/Constant.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/KnownBits.h"
+#include "llvm/Support/MachineValueType.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Target/TargetLowering.h"
+#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
-#include "llvm/Target/TargetRegisterInfo.h"
-#include "llvm/Target/TargetSubtargetInfo.h"
#include <algorithm>
+#include <cassert>
+#include <cstdint>
+#include <functional>
+#include <iterator>
+#include <string>
+#include <tuple>
+#include <utility>
+
using namespace llvm;
#define DEBUG_TYPE "dagcombine"
@@ -48,51 +84,46 @@ STATISTIC(PostIndexedNodes, "Number of post-indexed nodes created");
STATISTIC(OpsNarrowed , "Number of load/op/store narrowed");
STATISTIC(LdStFP2Int , "Number of fp load/store pairs transformed to int");
STATISTIC(SlicedLoads, "Number of load sliced");
+STATISTIC(NumFPLogicOpsConv, "Number of logic ops converted to fp ops");
-namespace {
- static cl::opt<bool>
- CombinerAA("combiner-alias-analysis", cl::Hidden,
- cl::desc("Enable DAG combiner alias-analysis heuristics"));
+static cl::opt<bool>
+CombinerGlobalAA("combiner-global-alias-analysis", cl::Hidden,
+ cl::desc("Enable DAG combiner's use of IR alias analysis"));
- static cl::opt<bool>
- CombinerGlobalAA("combiner-global-alias-analysis", cl::Hidden,
- cl::desc("Enable DAG combiner's use of IR alias analysis"));
-
- static cl::opt<bool>
- UseTBAA("combiner-use-tbaa", cl::Hidden, cl::init(true),
- cl::desc("Enable DAG combiner's use of TBAA"));
+static cl::opt<bool>
+UseTBAA("combiner-use-tbaa", cl::Hidden, cl::init(true),
+ cl::desc("Enable DAG combiner's use of TBAA"));
#ifndef NDEBUG
- static cl::opt<std::string>
- CombinerAAOnlyFunc("combiner-aa-only-func", cl::Hidden,
- cl::desc("Only use DAG-combiner alias analysis in this"
- " function"));
+static cl::opt<std::string>
+CombinerAAOnlyFunc("combiner-aa-only-func", cl::Hidden,
+ cl::desc("Only use DAG-combiner alias analysis in this"
+ " function"));
#endif
- /// Hidden option to stress test load slicing, i.e., when this option
- /// is enabled, load slicing bypasses most of its profitability guards.
- static cl::opt<bool>
- StressLoadSlicing("combiner-stress-load-slicing", cl::Hidden,
- cl::desc("Bypass the profitability model of load "
- "slicing"),
- cl::init(false));
+/// Hidden option to stress test load slicing, i.e., when this option
+/// is enabled, load slicing bypasses most of its profitability guards.
+static cl::opt<bool>
+StressLoadSlicing("combiner-stress-load-slicing", cl::Hidden,
+ cl::desc("Bypass the profitability model of load slicing"),
+ cl::init(false));
- static cl::opt<bool>
- MaySplitLoadIndex("combiner-split-load-index", cl::Hidden, cl::init(true),
- cl::desc("DAG combiner may split indexing from loads"));
+static cl::opt<bool>
+ MaySplitLoadIndex("combiner-split-load-index", cl::Hidden, cl::init(true),
+ cl::desc("DAG combiner may split indexing from loads"));
-//------------------------------ DAGCombiner ---------------------------------//
+namespace {
class DAGCombiner {
SelectionDAG &DAG;
const TargetLowering &TLI;
CombineLevel Level;
CodeGenOpt::Level OptLevel;
- bool LegalOperations;
- bool LegalTypes;
+ bool LegalOperations = false;
+ bool LegalTypes = false;
bool ForCodeSize;
- /// \brief Worklist of all of the nodes that need to be simplified.
+ /// Worklist of all of the nodes that need to be simplified.
///
/// This must behave as a stack -- new nodes to process are pushed onto the
/// back and when processing we pop off of the back.
@@ -101,21 +132,21 @@ namespace {
/// due to nodes being deleted from the underlying DAG.
SmallVector<SDNode *, 64> Worklist;
- /// \brief Mapping from an SDNode to its position on the worklist.
+ /// Mapping from an SDNode to its position on the worklist.
///
/// This is used to find and remove nodes from the worklist (by nulling
/// them) when they are deleted from the underlying DAG. It relies on
/// stable indices of nodes within the worklist.
DenseMap<SDNode *, unsigned> WorklistMap;
- /// \brief Set of nodes which have been combined (at least once).
+ /// Set of nodes which have been combined (at least once).
///
/// This is used to allow us to reliably add any operands of a DAG node
/// which have not yet been combined to the worklist.
- SmallPtrSet<SDNode *, 64> CombinedNodes;
+ SmallPtrSet<SDNode *, 32> CombinedNodes;
// AA - Used for DAG load/store alias analysis.
- AliasAnalysis &AA;
+ AliasAnalysis *AA;
/// When an instruction is simplified, add all users of the instruction to
/// the work lists because they might get more simplified now.
@@ -128,9 +159,25 @@ namespace {
SDValue visit(SDNode *N);
public:
+ DAGCombiner(SelectionDAG &D, AliasAnalysis *AA, CodeGenOpt::Level OL)
+ : DAG(D), TLI(D.getTargetLoweringInfo()), Level(BeforeLegalizeTypes),
+ OptLevel(OL), AA(AA) {
+ ForCodeSize = DAG.getMachineFunction().getFunction().optForSize();
+
+ MaximumLegalStoreInBits = 0;
+ for (MVT VT : MVT::all_valuetypes())
+ if (EVT(VT).isSimple() && VT != MVT::Other &&
+ TLI.isTypeLegal(EVT(VT)) &&
+ VT.getSizeInBits() >= MaximumLegalStoreInBits)
+ MaximumLegalStoreInBits = VT.getSizeInBits();
+ }
+
/// Add to the worklist making sure its instance is at the back (next to be
/// processed.)
void AddToWorklist(SDNode *N) {
+ assert(N->getOpcode() != ISD::DELETED_NODE &&
+ "Deleted Node added to Worklist");
+
// Skip handle nodes as they can't usefully be combined and confuse the
// zero-use deletion strategy.
if (N->getOpcode() == ISD::HANDLENODE)
@@ -175,24 +222,41 @@ namespace {
void CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO);
private:
+ unsigned MaximumLegalStoreInBits;
/// Check the specified integer node value to see if it can be simplified or
/// if things it uses can be simplified by bit propagation.
/// If so, return true.
bool SimplifyDemandedBits(SDValue Op) {
- unsigned BitWidth = Op.getValueType().getScalarType().getSizeInBits();
+ unsigned BitWidth = Op.getScalarValueSizeInBits();
APInt Demanded = APInt::getAllOnesValue(BitWidth);
return SimplifyDemandedBits(Op, Demanded);
}
+ /// Check the specified vector node value to see if it can be simplified or
+ /// if things it uses can be simplified as it only uses some of the
+ /// elements. If so, return true.
+ bool SimplifyDemandedVectorElts(SDValue Op) {
+ unsigned NumElts = Op.getValueType().getVectorNumElements();
+ APInt Demanded = APInt::getAllOnesValue(NumElts);
+ return SimplifyDemandedVectorElts(Op, Demanded);
+ }
+
bool SimplifyDemandedBits(SDValue Op, const APInt &Demanded);
+ bool SimplifyDemandedVectorElts(SDValue Op, const APInt &Demanded,
+ bool AssumeSingleUse = false);
bool CombineToPreIndexedLoadStore(SDNode *N);
bool CombineToPostIndexedLoadStore(SDNode *N);
SDValue SplitIndexingFromLoad(LoadSDNode *LD);
bool SliceUpLoad(SDNode *N);
- /// \brief Replace an ISD::EXTRACT_VECTOR_ELT of a load with a narrowed
+ // Scalars have size 0 to distinguish from singleton vectors.
+ SDValue ForwardStoreValueToDirectLoad(LoadSDNode *LD);
+ bool getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val);
+ bool extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val);
+
+ /// Replace an ISD::EXTRACT_VECTOR_ELT of a load with a narrowed
/// load.
///
/// \param EVE ISD::EXTRACT_VECTOR_ELT to be replaced.
@@ -200,8 +264,9 @@ namespace {
/// \param EltNo index of the vector element to load.
/// \param OriginalLoad load that EVE came from to be replaced.
/// \returns EVE on success SDValue() on failure.
- SDValue ReplaceExtractVectorEltOfLoadWithNarrowedLoad(
- SDNode *EVE, EVT InVecVT, SDValue EltNo, LoadSDNode *OriginalLoad);
+ SDValue scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
+ SDValue EltNo,
+ LoadSDNode *OriginalLoad);
void ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad);
SDValue PromoteOperand(SDValue Op, EVT PVT, bool &Replace);
SDValue SExtPromoteOperand(SDValue Op, EVT PVT);
@@ -211,10 +276,6 @@ namespace {
SDValue PromoteExtend(SDValue Op);
bool PromoteLoad(SDValue Op);
- void ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
- SDValue Trunc, SDValue ExtLoad, SDLoc DL,
- ISD::NodeType ExtType);
-
/// Call the node-specific routine that knows how to fold each
/// particular type of node. If that doesn't do anything, try the
/// target-specific DAG combines.
@@ -230,15 +291,26 @@ namespace {
SDValue visitTokenFactor(SDNode *N);
SDValue visitMERGE_VALUES(SDNode *N);
SDValue visitADD(SDNode *N);
+ SDValue visitADDLike(SDValue N0, SDValue N1, SDNode *LocReference);
SDValue visitSUB(SDNode *N);
+ SDValue visitADDSAT(SDNode *N);
+ SDValue visitSUBSAT(SDNode *N);
SDValue visitADDC(SDNode *N);
+ SDValue visitUADDO(SDNode *N);
+ SDValue visitUADDOLike(SDValue N0, SDValue N1, SDNode *N);
SDValue visitSUBC(SDNode *N);
+ SDValue visitUSUBO(SDNode *N);
SDValue visitADDE(SDNode *N);
+ SDValue visitADDCARRY(SDNode *N);
+ SDValue visitADDCARRYLike(SDValue N0, SDValue N1, SDValue CarryIn, SDNode *N);
SDValue visitSUBE(SDNode *N);
+ SDValue visitSUBCARRY(SDNode *N);
SDValue visitMUL(SDNode *N);
SDValue useDivRem(SDNode *N);
SDValue visitSDIV(SDNode *N);
+ SDValue visitSDIVLike(SDValue N0, SDValue N1, SDNode *N);
SDValue visitUDIV(SDNode *N);
+ SDValue visitUDIVLike(SDValue N0, SDValue N1, SDNode *N);
SDValue visitREM(SDNode *N);
SDValue visitMULHU(SDNode *N);
SDValue visitMULHS(SDNode *N);
@@ -248,16 +320,19 @@ namespace {
SDValue visitUMULO(SDNode *N);
SDValue visitIMINMAX(SDNode *N);
SDValue visitAND(SDNode *N);
- SDValue visitANDLike(SDValue N0, SDValue N1, SDNode *LocReference);
+ SDValue visitANDLike(SDValue N0, SDValue N1, SDNode *N);
SDValue visitOR(SDNode *N);
- SDValue visitORLike(SDValue N0, SDValue N1, SDNode *LocReference);
+ SDValue visitORLike(SDValue N0, SDValue N1, SDNode *N);
SDValue visitXOR(SDNode *N);
SDValue SimplifyVBinOp(SDNode *N);
SDValue visitSHL(SDNode *N);
SDValue visitSRA(SDNode *N);
SDValue visitSRL(SDNode *N);
+ SDValue visitFunnelShift(SDNode *N);
SDValue visitRotate(SDNode *N);
+ SDValue visitABS(SDNode *N);
SDValue visitBSWAP(SDNode *N);
+ SDValue visitBITREVERSE(SDNode *N);
SDValue visitCTLZ(SDNode *N);
SDValue visitCTLZ_ZERO_UNDEF(SDNode *N);
SDValue visitCTTZ(SDNode *N);
@@ -267,12 +342,14 @@ namespace {
SDValue visitVSELECT(SDNode *N);
SDValue visitSELECT_CC(SDNode *N);
SDValue visitSETCC(SDNode *N);
- SDValue visitSETCCE(SDNode *N);
+ SDValue visitSETCCCARRY(SDNode *N);
SDValue visitSIGN_EXTEND(SDNode *N);
SDValue visitZERO_EXTEND(SDNode *N);
SDValue visitANY_EXTEND(SDNode *N);
+ SDValue visitAssertExt(SDNode *N);
SDValue visitSIGN_EXTEND_INREG(SDNode *N);
SDValue visitSIGN_EXTEND_VECTOR_INREG(SDNode *N);
+ SDValue visitZERO_EXTEND_VECTOR_INREG(SDNode *N);
SDValue visitTRUNCATE(SDNode *N);
SDValue visitBITCAST(SDNode *N);
SDValue visitBUILD_PAIR(SDNode *N);
@@ -284,6 +361,7 @@ namespace {
SDValue visitFREM(SDNode *N);
SDValue visitFSQRT(SDNode *N);
SDValue visitFCOPYSIGN(SDNode *N);
+ SDValue visitFPOW(SDNode *N);
SDValue visitSINT_TO_FP(SDNode *N);
SDValue visitUINT_TO_FP(SDNode *N);
SDValue visitFP_TO_SINT(SDNode *N);
@@ -298,6 +376,8 @@ namespace {
SDValue visitFFLOOR(SDNode *N);
SDValue visitFMINNUM(SDNode *N);
SDValue visitFMAXNUM(SDNode *N);
+ SDValue visitFMINIMUM(SDNode *N);
+ SDValue visitFMAXIMUM(SDNode *N);
SDValue visitBRCOND(SDNode *N);
SDValue visitBR_CC(SDNode *N);
SDValue visitLOAD(SDNode *N);
@@ -323,21 +403,35 @@ namespace {
SDValue visitFADDForFMACombine(SDNode *N);
SDValue visitFSUBForFMACombine(SDNode *N);
- SDValue visitFMULForFMACombine(SDNode *N);
+ SDValue visitFMULForFMADistributiveCombine(SDNode *N);
SDValue XformToShuffleWithZero(SDNode *N);
- SDValue ReassociateOps(unsigned Opc, SDLoc DL, SDValue LHS, SDValue RHS);
+ SDValue ReassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
+ SDValue N1, SDNodeFlags Flags);
SDValue visitShiftByConstant(SDNode *N, ConstantSDNode *Amt);
+ SDValue foldSelectOfConstants(SDNode *N);
+ SDValue foldVSelectOfConstants(SDNode *N);
+ SDValue foldBinOpIntoSelect(SDNode *BO);
bool SimplifySelectOps(SDNode *SELECT, SDValue LHS, SDValue RHS);
- SDValue SimplifyBinOpWithSameOpcodeHands(SDNode *N);
- SDValue SimplifySelect(SDLoc DL, SDValue N0, SDValue N1, SDValue N2);
- SDValue SimplifySelectCC(SDLoc DL, SDValue N0, SDValue N1, SDValue N2,
- SDValue N3, ISD::CondCode CC,
+ SDValue hoistLogicOpWithSameOpcodeHands(SDNode *N);
+ SDValue SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2);
+ SDValue SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
+ SDValue N2, SDValue N3, ISD::CondCode CC,
bool NotExtCompare = false);
+ SDValue convertSelectOfFPConstantsToLoadOffset(
+ const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
+ ISD::CondCode CC);
+ SDValue foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0, SDValue N1,
+ SDValue N2, SDValue N3, ISD::CondCode CC);
+ SDValue foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
+ const SDLoc &DL);
+ SDValue unfoldMaskedMerge(SDNode *N);
+ SDValue unfoldExtremeBitClearingToShifts(SDNode *N);
SDValue SimplifySetCC(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
- SDLoc DL, bool foldBooleans = true);
+ const SDLoc &DL, bool foldBooleans);
+ SDValue rebuildSetCC(SDValue N);
bool isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
SDValue &CC) const;
@@ -347,32 +441,42 @@ namespace {
unsigned HiOp);
SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
SDValue CombineExtLoad(SDNode *N);
+ SDValue CombineZExtLogicopShiftLoad(SDNode *N);
SDValue combineRepeatedFPDivisors(SDNode *N);
+ SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex);
SDValue ConstantFoldBITCASTofBUILD_VECTOR(SDNode *, EVT);
SDValue BuildSDIV(SDNode *N);
SDValue BuildSDIVPow2(SDNode *N);
SDValue BuildUDIV(SDNode *N);
- SDValue BuildReciprocalEstimate(SDValue Op, SDNodeFlags *Flags);
- SDValue BuildRsqrtEstimate(SDValue Op, SDNodeFlags *Flags);
- SDValue BuildRsqrtNROneConst(SDValue Op, SDValue Est, unsigned Iterations,
- SDNodeFlags *Flags);
- SDValue BuildRsqrtNRTwoConst(SDValue Op, SDValue Est, unsigned Iterations,
- SDNodeFlags *Flags);
+ SDValue BuildLogBase2(SDValue V, const SDLoc &DL);
+ SDValue BuildReciprocalEstimate(SDValue Op, SDNodeFlags Flags);
+ SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags);
+ SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags Flags);
+ SDValue buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags, bool Recip);
+ SDValue buildSqrtNROneConst(SDValue Arg, SDValue Est, unsigned Iterations,
+ SDNodeFlags Flags, bool Reciprocal);
+ SDValue buildSqrtNRTwoConst(SDValue Arg, SDValue Est, unsigned Iterations,
+ SDNodeFlags Flags, bool Reciprocal);
SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
bool DemandHighBits = true);
SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
SDNode *MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg,
SDValue InnerPos, SDValue InnerNeg,
unsigned PosOpcode, unsigned NegOpcode,
- SDLoc DL);
- SDNode *MatchRotate(SDValue LHS, SDValue RHS, SDLoc DL);
+ const SDLoc &DL);
+ SDNode *MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
+ SDValue MatchLoadCombine(SDNode *N);
SDValue ReduceLoadWidth(SDNode *N);
SDValue ReduceLoadOpStoreWidth(SDNode *N);
+ SDValue splitMergedValStore(StoreSDNode *ST);
SDValue TransformFPLoadStorePair(SDNode *N);
+ SDValue convertBuildVecZextToZext(SDNode *N);
SDValue reduceBuildVecExtToExtBuildVec(SDNode *N);
- SDValue reduceBuildVecConvertToConvertBuildVec(SDNode *N);
-
- SDValue GetDemandedBits(SDValue V, const APInt &Mask);
+ SDValue reduceBuildVecToShuffle(SDNode *N);
+ SDValue createBuildVecShuffle(const SDLoc &DL, SDNode *N,
+ ArrayRef<int> VectorMask, SDValue VecIn1,
+ SDValue VecIn2, unsigned LeftIdx);
+ SDValue matchVSelectOpSizesWithSetCC(SDNode *Cast);
/// Walk up chain skipping non-aliasing memory nodes,
/// looking for aliasing nodes and adding them to the Aliases vector.
@@ -386,22 +490,29 @@ namespace {
/// chain (aliasing node.)
SDValue FindBetterChain(SDNode *N, SDValue Chain);
- /// Do FindBetterChain for a store and any possibly adjacent stores on
- /// consecutive chains.
+ /// Try to replace a store and any possibly adjacent stores on
+ /// consecutive chains with better chains. Return true only if St is
+ /// replaced.
+ ///
+ /// Notice that other chains may still be replaced even if the function
+ /// returns false.
bool findBetterNeighborChains(StoreSDNode *St);
+ // Helper for findBetterNeighborChains. Walk up the store chain and add
+ // additional chained stores that do not overlap and can be parallelized.
+ bool parallelizeChainedStores(StoreSDNode *St);
+
/// Holds a pointer to an LSBaseSDNode as well as information on where it
/// is located in a sequence of memory operations connected by a chain.
struct MemOpLink {
- MemOpLink (LSBaseSDNode *N, int64_t Offset, unsigned Seq):
- MemNode(N), OffsetFromBase(Offset), SequenceNum(Seq) { }
// Ptr to the mem node.
LSBaseSDNode *MemNode;
+
// Offset from the base ptr.
int64_t OffsetFromBase;
- // What is the sequence number of this mem node.
- // Lowest mem operand in the DAG starts at zero.
- unsigned SequenceNum;
+
+ MemOpLink(LSBaseSDNode *N, int64_t Offset)
+ : MemNode(N), OffsetFromBase(Offset) {}
};
/// This is a helper function for visitMUL to check the profitability
@@ -412,44 +523,64 @@ namespace {
SDValue &AddNode,
SDValue &ConstNode);
- /// This is a helper function for MergeStoresOfConstantsOrVecElts. Returns a
- /// constant build_vector of the stored constant values in Stores.
- SDValue getMergedConstantVectorStore(SelectionDAG &DAG,
- SDLoc SL,
- ArrayRef<MemOpLink> Stores,
- SmallVectorImpl<SDValue> &Chains,
- EVT Ty) const;
-
/// This is a helper function for visitAND and visitZERO_EXTEND. Returns
/// true if the (and (load x) c) pattern matches an extload. ExtVT returns
- /// the type of the loaded value to be extended. LoadedVT returns the type
- /// of the original loaded value. NarrowLoad returns whether the load would
- /// need to be narrowed in order to match.
+ /// the type of the loaded value to be extended.
bool isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
- EVT LoadResultTy, EVT &ExtVT, EVT &LoadedVT,
- bool &NarrowLoad);
-
- /// This is a helper function for MergeConsecutiveStores. When the source
- /// elements of the consecutive stores are all constants or all extracted
- /// vector elements, try to merge them into one larger store.
- /// \return True if a merged store was created.
+ EVT LoadResultTy, EVT &ExtVT);
+
+ /// Helper function to calculate whether the given Load/Store can have its
+ /// width reduced to ExtVT.
+ bool isLegalNarrowLdSt(LSBaseSDNode *LDSTN, ISD::LoadExtType ExtType,
+ EVT &MemVT, unsigned ShAmt = 0);
+
+ /// Used by BackwardsPropagateMask to find suitable loads.
+ bool SearchForAndLoads(SDNode *N, SmallVectorImpl<LoadSDNode*> &Loads,
+ SmallPtrSetImpl<SDNode*> &NodesWithConsts,
+ ConstantSDNode *Mask, SDNode *&NodeToMask);
+ /// Attempt to propagate a given AND node back to load leaves so that they
+ /// can be combined into narrow loads.
+ bool BackwardsPropagateMask(SDNode *N, SelectionDAG &DAG);
+
+ /// Helper function for MergeConsecutiveStores which merges the
+ /// component store chains.
+ SDValue getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
+ unsigned NumStores);
+
+ /// This is a helper function for MergeConsecutiveStores. When the
+ /// source elements of the consecutive stores are all constants or
+ /// all extracted vector elements, try to merge them into one
+ /// larger store introducing bitcasts if necessary. \return True
+ /// if a merged store was created.
bool MergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> &StoreNodes,
EVT MemVT, unsigned NumStores,
- bool IsConstantSrc, bool UseVector);
-
- /// This is a helper function for MergeConsecutiveStores.
- /// Stores that may be merged are placed in StoreNodes.
- /// Loads that may alias with those stores are placed in AliasLoadNodes.
- void getStoreMergeAndAliasCandidates(
- StoreSDNode* St, SmallVectorImpl<MemOpLink> &StoreNodes,
- SmallVectorImpl<LSBaseSDNode*> &AliasLoadNodes);
+ bool IsConstantSrc, bool UseVector,
+ bool UseTrunc);
+
+ /// This is a helper function for MergeConsecutiveStores. Stores
+ /// that potentially may be merged with St are placed in
+ /// StoreNodes. RootNode is a chain predecessor to all store
+ /// candidates.
+ void getStoreMergeCandidates(StoreSDNode *St,
+ SmallVectorImpl<MemOpLink> &StoreNodes,
+ SDNode *&Root);
+
+ /// Helper function for MergeConsecutiveStores. Checks if
+ /// candidate stores have indirect dependency through their
+ /// operands. RootNode is the predecessor to all stores calculated
+ /// by getStoreMergeCandidates and is used to prune the dependency check.
+ /// \return True if safe to merge.
+ bool checkMergeStoreCandidatesForDependencies(
+ SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
+ SDNode *RootNode);
/// Merge consecutive store operations into a wide store.
/// This optimization uses wide integers or vectors when possible.
- /// \return True if some memory operations were changed.
- bool MergeConsecutiveStores(StoreSDNode *N);
+ /// \return number of stores that were merged into a merged store (the
+ /// affected nodes are stored as a prefix in \p StoreNodes).
+ bool MergeConsecutiveStores(StoreSDNode *St);
- /// \brief Try to transform a truncation where C is a constant:
+ /// Try to transform a truncation where C is a constant:
/// (trunc (and X, C)) -> (and (trunc X), (trunc C))
///
/// \p N needs to be a truncation and its first operand an AND. Other
@@ -457,13 +588,17 @@ namespace {
/// single-use) and if missed an empty SDValue is returned.
SDValue distributeTruncateThroughAnd(SDNode *N);
- public:
- DAGCombiner(SelectionDAG &D, AliasAnalysis &A, CodeGenOpt::Level OL)
- : DAG(D), TLI(D.getTargetLoweringInfo()), Level(BeforeLegalizeTypes),
- OptLevel(OL), LegalOperations(false), LegalTypes(false), AA(A) {
- ForCodeSize = DAG.getMachineFunction().getFunction()->optForSize();
+ /// Helper function to determine whether the target supports operation
+ /// given by \p Opcode for type \p VT, that is, whether the operation
+ /// is legal or custom before legalizing operations, and whether is
+ /// legal (but not custom) after legalization.
+ bool hasOperation(unsigned Opcode, EVT VT) {
+ if (LegalOperations)
+ return TLI.isOperationLegal(Opcode, VT);
+ return TLI.isOperationLegalOrCustom(Opcode, VT);
}
+ public:
/// Runs the dag combiner on all nodes in the work list
void Run(CombineLevel AtLevel);
@@ -473,11 +608,7 @@ namespace {
/// legalization these can be huge.
EVT getShiftAmountTy(EVT LHSTy) {
assert(LHSTy.isInteger() && "Shift amount is not an integer type!");
- if (LHSTy.isVector())
- return LHSTy;
- auto &DL = DAG.getDataLayout();
- return LegalTypes ? TLI.getScalarShiftAmountTy(DL, LHSTy)
- : TLI.getPointerTy(DL);
+ return TLI.getShiftAmountTy(LHSTy, DAG.getDataLayout(), LegalTypes);
}
/// This method returns true if we are running before type legalization or
@@ -491,15 +622,17 @@ namespace {
EVT getSetCCResultType(EVT VT) const {
return TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
}
- };
-}
+ void ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
+ SDValue OrigLoad, SDValue ExtLoad,
+ ISD::NodeType ExtType);
+ };
-namespace {
/// This class is a DAGUpdateListener that removes any deleted
/// nodes from the worklist.
class WorklistRemover : public SelectionDAG::DAGUpdateListener {
DAGCombiner &DC;
+
public:
explicit WorklistRemover(DAGCombiner &dc)
: SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}
@@ -508,7 +641,8 @@ public:
DC.removeFromWorklist(N);
}
};
-}
+
+} // end anonymous namespace
//===----------------------------------------------------------------------===//
// TargetLowering::DAGCombinerInfo implementation
@@ -518,10 +652,6 @@ void TargetLowering::DAGCombinerInfo::AddToWorklist(SDNode *N) {
((DAGCombiner*)DC)->AddToWorklist(N);
}
-void TargetLowering::DAGCombinerInfo::RemoveFromWorklist(SDNode *N) {
- ((DAGCombiner*)DC)->removeFromWorklist(N);
-}
-
SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, ArrayRef<SDValue> To, bool AddTo) {
return ((DAGCombiner*)DC)->CombineTo(N, &To[0], To.size(), AddTo);
@@ -532,7 +662,6 @@ CombineTo(SDNode *N, SDValue Res, bool AddTo) {
return ((DAGCombiner*)DC)->CombineTo(N, Res, AddTo);
}
-
SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, SDValue Res0, SDValue Res1, bool AddTo) {
return ((DAGCombiner*)DC)->CombineTo(N, Res0, Res1, AddTo);
@@ -572,25 +701,34 @@ static char isNegatibleForFree(SDValue Op, bool LegalOperations,
// fneg is removable even if it has multiple uses.
if (Op.getOpcode() == ISD::FNEG) return 2;
- // Don't allow anything with multiple uses.
- if (!Op.hasOneUse()) return 0;
+ // Don't allow anything with multiple uses unless we know it is free.
+ EVT VT = Op.getValueType();
+ const SDNodeFlags Flags = Op->getFlags();
+ if (!Op.hasOneUse())
+ if (!(Op.getOpcode() == ISD::FP_EXTEND &&
+ TLI.isFPExtFree(VT, Op.getOperand(0).getValueType())))
+ return 0;
// Don't recurse exponentially.
if (Depth > 6) return 0;
switch (Op.getOpcode()) {
default: return false;
- case ISD::ConstantFP:
- // Don't invert constant FP values after legalize. The negated constant
- // isn't necessarily legal.
- return LegalOperations ? 0 : 1;
+ case ISD::ConstantFP: {
+ if (!LegalOperations)
+ return 1;
+
+ // Don't invert constant FP values after legalization unless the target says
+ // the negated constant is legal.
+ return TLI.isOperationLegal(ISD::ConstantFP, VT) ||
+ TLI.isFPImmLegal(neg(cast<ConstantFPSDNode>(Op)->getValueAPF()), VT);
+ }
case ISD::FADD:
- // FIXME: determine better conditions for this xform.
- if (!Options->UnsafeFPMath) return 0;
+ if (!Options->UnsafeFPMath && !Flags.hasNoSignedZeros())
+ return 0;
// After operation legalization, it might not be legal to create new FSUBs.
- if (LegalOperations &&
- !TLI.isOperationLegalOrCustom(ISD::FSUB, Op.getValueType()))
+ if (LegalOperations && !TLI.isOperationLegalOrCustom(ISD::FSUB, VT))
return 0;
// fold (fneg (fadd A, B)) -> (fsub (fneg A), B)
@@ -602,15 +740,15 @@ static char isNegatibleForFree(SDValue Op, bool LegalOperations,
Depth + 1);
case ISD::FSUB:
// We can't turn -(A-B) into B-A when we honor signed zeros.
- if (!Options->UnsafeFPMath) return 0;
+ if (!Options->NoSignedZerosFPMath &&
+ !Flags.hasNoSignedZeros())
+ return 0;
// fold (fneg (fsub A, B)) -> (fsub B, A)
return 1;
case ISD::FMUL:
case ISD::FDIV:
- if (Options->HonorSignDependentRoundingFPMath()) return 0;
-
// fold (fneg (fmul X, Y)) -> (fmul (fneg X), Y) or (fmul X, (fneg Y))
if (char V = isNegatibleForFree(Op.getOperand(0), LegalOperations, TLI,
Options, Depth + 1))
@@ -634,12 +772,9 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG,
// fneg is removable even if it has multiple uses.
if (Op.getOpcode() == ISD::FNEG) return Op.getOperand(0);
- // Don't allow anything with multiple uses.
- assert(Op.hasOneUse() && "Unknown reuse!");
-
assert(Depth <= 6 && "GetNegatedExpression doesn't match isNegatibleForFree");
- const SDNodeFlags *Flags = Op.getNode()->getFlags();
+ const SDNodeFlags Flags = Op.getNode()->getFlags();
switch (Op.getOpcode()) {
default: llvm_unreachable("Unknown code");
@@ -649,8 +784,7 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG,
return DAG.getConstantFP(V, SDLoc(Op), Op.getValueType());
}
case ISD::FADD:
- // FIXME: determine better conditions for this xform.
- assert(Options.UnsafeFPMath);
+ assert(Options.UnsafeFPMath || Flags.hasNoSignedZeros());
// fold (fneg (fadd A, B)) -> (fsub (fneg A), B)
if (isNegatibleForFree(Op.getOperand(0), LegalOperations,
@@ -665,9 +799,6 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG,
LegalOperations, Depth+1),
Op.getOperand(0), Flags);
case ISD::FSUB:
- // We can't turn -(A-B) into B-A when we honor signed zeros.
- assert(Options.UnsafeFPMath);
-
// fold (fneg (fsub 0, B)) -> B
if (ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(Op.getOperand(0)))
if (N0CFP->isZero())
@@ -679,8 +810,6 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG,
case ISD::FMUL:
case ISD::FDIV:
- assert(!Options.HonorSignDependentRoundingFPMath());
-
// fold (fneg (fmul X, Y)) -> (fmul (fneg X), Y)
if (isNegatibleForFree(Op.getOperand(0), LegalOperations,
DAG.getTargetLoweringInfo(), &Options, Depth+1))
@@ -708,6 +837,15 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG,
}
}
+// APInts must be the same size for most operations; this helper
+// function zero extends the shorter of the pair so that they match.
+// We provide an Offset so that we can create bitwidths that won't overflow.
+static void zeroExtendToMatch(APInt &LHS, APInt &RHS, unsigned Offset = 0) {
+ unsigned Bits = Offset + std::max(LHS.getBitWidth(), RHS.getBitWidth());
+ LHS = LHS.zextOrSelf(Bits);
+ RHS = RHS.zextOrSelf(Bits);
+}
+
// Return true if this node is a setcc, or is a select_cc
// that selects between the target values used for true and false, making it
// equivalent to a setcc. Also, set the incoming LHS, RHS, and CC references to
@@ -747,33 +885,7 @@ bool DAGCombiner::isOneUseSetCC(SDValue N) const {
return false;
}
-/// Returns true if N is a BUILD_VECTOR node whose
-/// elements are all the same constant or undefined.
-static bool isConstantSplatVector(SDNode *N, APInt& SplatValue) {
- BuildVectorSDNode *C = dyn_cast<BuildVectorSDNode>(N);
- if (!C)
- return false;
-
- APInt SplatUndef;
- unsigned SplatBitSize;
- bool HasAnyUndefs;
- EVT EltVT = N->getValueType(0).getVectorElementType();
- return (C->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
- HasAnyUndefs) &&
- EltVT.getSizeInBits() >= SplatBitSize);
-}
-
-// \brief Returns the SDNode if it is a constant integer BuildVector
-// or constant integer.
-static SDNode *isConstantIntBuildVectorOrConstantInt(SDValue N) {
- if (isa<ConstantSDNode>(N))
- return N.getNode();
- if (ISD::isBuildVectorOfConstantSDNodes(N.getNode()))
- return N.getNode();
- return nullptr;
-}
-
-// \brief Returns the SDNode if it is a constant float BuildVector
+// Returns the SDNode if it is a constant float BuildVector
// or constant float.
static SDNode *isConstantFPBuildVectorOrConstantFP(SDValue N) {
if (isa<ConstantFPSDNode>(N))
@@ -783,50 +895,45 @@ static SDNode *isConstantFPBuildVectorOrConstantFP(SDValue N) {
return nullptr;
}
-// \brief Returns the SDNode if it is a constant splat BuildVector or constant
-// int.
-static ConstantSDNode *isConstOrConstSplat(SDValue N) {
- if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N))
- return CN;
-
- if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N)) {
- BitVector UndefElements;
- ConstantSDNode *CN = BV->getConstantSplatNode(&UndefElements);
-
- // BuildVectors can truncate their operands. Ignore that case here.
- // FIXME: We blindly ignore splats which include undef which is overly
- // pessimistic.
- if (CN && UndefElements.none() &&
- CN->getValueType(0) == N.getValueType().getScalarType())
- return CN;
+// Determines if it is a constant integer or a build vector of constant
+// integers (and undefs).
+// Do not permit build vector implicit truncation.
+static bool isConstantOrConstantVector(SDValue N, bool NoOpaques = false) {
+ if (ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N))
+ return !(Const->isOpaque() && NoOpaques);
+ if (N.getOpcode() != ISD::BUILD_VECTOR)
+ return false;
+ unsigned BitWidth = N.getScalarValueSizeInBits();
+ for (const SDValue &Op : N->op_values()) {
+ if (Op.isUndef())
+ continue;
+ ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Op);
+ if (!Const || Const->getAPIntValue().getBitWidth() != BitWidth ||
+ (Const->isOpaque() && NoOpaques))
+ return false;
}
-
- return nullptr;
+ return true;
}
-// \brief Returns the SDNode if it is a constant splat BuildVector or constant
-// float.
-static ConstantFPSDNode *isConstOrConstSplatFP(SDValue N) {
- if (ConstantFPSDNode *CN = dyn_cast<ConstantFPSDNode>(N))
- return CN;
-
- if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N)) {
- BitVector UndefElements;
- ConstantFPSDNode *CN = BV->getConstantFPSplatNode(&UndefElements);
-
- if (CN && UndefElements.none())
- return CN;
- }
-
- return nullptr;
+// Determines if a BUILD_VECTOR is composed of all-constants possibly mixed with
+// undef's.
+static bool isAnyConstantBuildVector(SDValue V, bool NoOpaques = false) {
+ if (V.getOpcode() != ISD::BUILD_VECTOR)
+ return false;
+ return isConstantOrConstantVector(V, NoOpaques) ||
+ ISD::isBuildVectorOfConstantFPSDNodes(V.getNode());
}
-SDValue DAGCombiner::ReassociateOps(unsigned Opc, SDLoc DL,
- SDValue N0, SDValue N1) {
+SDValue DAGCombiner::ReassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
+ SDValue N1, SDNodeFlags Flags) {
+ // Don't reassociate reductions.
+ if (Flags.hasVectorReduction())
+ return SDValue();
+
EVT VT = N0.getValueType();
- if (N0.getOpcode() == Opc) {
- if (SDNode *L = isConstantIntBuildVectorOrConstantInt(N0.getOperand(1))) {
- if (SDNode *R = isConstantIntBuildVectorOrConstantInt(N1)) {
+ if (N0.getOpcode() == Opc && !N0->getFlags().hasVectorReduction()) {
+ if (SDNode *L = DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1))) {
+ if (SDNode *R = DAG.isConstantIntBuildVectorOrConstantInt(N1)) {
// reassoc. (op (op x, c1), c2) -> (op x, (op c1, c2))
if (SDValue OpNode = DAG.FoldConstantArithmetic(Opc, DL, VT, L, R))
return DAG.getNode(Opc, DL, VT, N0.getOperand(0), OpNode);
@@ -844,18 +951,18 @@ SDValue DAGCombiner::ReassociateOps(unsigned Opc, SDLoc DL,
}
}
- if (N1.getOpcode() == Opc) {
- if (SDNode *R = isConstantIntBuildVectorOrConstantInt(N1.getOperand(1))) {
- if (SDNode *L = isConstantIntBuildVectorOrConstantInt(N0)) {
+ if (N1.getOpcode() == Opc && !N1->getFlags().hasVectorReduction()) {
+ if (SDNode *R = DAG.isConstantIntBuildVectorOrConstantInt(N1.getOperand(1))) {
+ if (SDNode *L = DAG.isConstantIntBuildVectorOrConstantInt(N0)) {
// reassoc. (op c2, (op x, c1)) -> (op x, (op c1, c2))
if (SDValue OpNode = DAG.FoldConstantArithmetic(Opc, DL, VT, R, L))
return DAG.getNode(Opc, DL, VT, N1.getOperand(0), OpNode);
return SDValue();
}
if (N1.hasOneUse()) {
- // reassoc. (op y, (op x, c1)) -> (op (op x, y), c1) iff x+c1 has one
+ // reassoc. (op x, (op y, c1)) -> (op (op x, y), c1) iff x+c1 has one
// use
- SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N1.getOperand(0), N0);
+ SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N0, N1.getOperand(0));
if (!OpNode.getNode())
return SDValue();
AddToWorklist(OpNode.getNode());
@@ -871,11 +978,9 @@ SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
bool AddTo) {
assert(N->getNumValues() == NumTo && "Broken CombineTo call!");
++NodesCombined;
- DEBUG(dbgs() << "\nReplacing.1 ";
- N->dump(&DAG);
- dbgs() << "\nWith: ";
- To[0].getNode()->dump(&DAG);
- dbgs() << " and " << NumTo-1 << " other values\n");
+ LLVM_DEBUG(dbgs() << "\nReplacing.1 "; N->dump(&DAG); dbgs() << "\nWith: ";
+ To[0].getNode()->dump(&DAG);
+ dbgs() << " and " << NumTo - 1 << " other values\n");
for (unsigned i = 0, e = NumTo; i != e; ++i)
assert((!To[i].getNode() ||
N->getValueType(i) == To[i].getValueType()) &&
@@ -923,8 +1028,32 @@ CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
/// things it uses can be simplified by bit propagation. If so, return true.
bool DAGCombiner::SimplifyDemandedBits(SDValue Op, const APInt &Demanded) {
TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
- APInt KnownZero, KnownOne;
- if (!TLI.SimplifyDemandedBits(Op, Demanded, KnownZero, KnownOne, TLO))
+ KnownBits Known;
+ if (!TLI.SimplifyDemandedBits(Op, Demanded, Known, TLO))
+ return false;
+
+ // Revisit the node.
+ AddToWorklist(Op.getNode());
+
+ // Replace the old value with the new one.
+ ++NodesCombined;
+ LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.getNode()->dump(&DAG);
+ dbgs() << "\nWith: "; TLO.New.getNode()->dump(&DAG);
+ dbgs() << '\n');
+
+ CommitTargetLoweringOpt(TLO);
+ return true;
+}
+
+/// Check the specified vector node value to see if it can be simplified or
+/// if things it uses can be simplified as it only uses some of the elements.
+/// If so, return true.
+bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op, const APInt &Demanded,
+ bool AssumeSingleUse) {
+ TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
+ APInt KnownUndef, KnownZero;
+ if (!TLI.SimplifyDemandedVectorElts(Op, Demanded, KnownUndef, KnownZero, TLO,
+ 0, AssumeSingleUse))
return false;
// Revisit the node.
@@ -932,26 +1061,21 @@ bool DAGCombiner::SimplifyDemandedBits(SDValue Op, const APInt &Demanded) {
// Replace the old value with the new one.
++NodesCombined;
- DEBUG(dbgs() << "\nReplacing.2 ";
- TLO.Old.getNode()->dump(&DAG);
- dbgs() << "\nWith: ";
- TLO.New.getNode()->dump(&DAG);
- dbgs() << '\n');
+ LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.getNode()->dump(&DAG);
+ dbgs() << "\nWith: "; TLO.New.getNode()->dump(&DAG);
+ dbgs() << '\n');
CommitTargetLoweringOpt(TLO);
return true;
}
void DAGCombiner::ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad) {
- SDLoc dl(Load);
+ SDLoc DL(Load);
EVT VT = Load->getValueType(0);
- SDValue Trunc = DAG.getNode(ISD::TRUNCATE, dl, VT, SDValue(ExtLoad, 0));
+ SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, SDValue(ExtLoad, 0));
- DEBUG(dbgs() << "\nReplacing.9 ";
- Load->dump(&DAG);
- dbgs() << "\nWith: ";
- Trunc.getNode()->dump(&DAG);
- dbgs() << '\n');
+ LLVM_DEBUG(dbgs() << "\nReplacing.9 "; Load->dump(&DAG); dbgs() << "\nWith: ";
+ Trunc.getNode()->dump(&DAG); dbgs() << '\n');
WorklistRemover DeadNodes(*this);
DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), Trunc);
DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), SDValue(ExtLoad, 1));
@@ -961,15 +1085,14 @@ void DAGCombiner::ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad) {
SDValue DAGCombiner::PromoteOperand(SDValue Op, EVT PVT, bool &Replace) {
Replace = false;
- SDLoc dl(Op);
- if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Op)) {
+ SDLoc DL(Op);
+ if (ISD::isUNINDEXEDLoad(Op.getNode())) {
+ LoadSDNode *LD = cast<LoadSDNode>(Op);
EVT MemVT = LD->getMemoryVT();
- ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD)
- ? (TLI.isLoadExtLegal(ISD::ZEXTLOAD, PVT, MemVT) ? ISD::ZEXTLOAD
- : ISD::EXTLOAD)
- : LD->getExtensionType();
+ ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD
+ : LD->getExtensionType();
Replace = true;
- return DAG.getExtLoad(ExtType, dl, PVT,
+ return DAG.getExtLoad(ExtType, DL, PVT,
LD->getChain(), LD->getBasePtr(),
MemVT, LD->getMemOperand());
}
@@ -978,30 +1101,30 @@ SDValue DAGCombiner::PromoteOperand(SDValue Op, EVT PVT, bool &Replace) {
switch (Opc) {
default: break;
case ISD::AssertSext:
- return DAG.getNode(ISD::AssertSext, dl, PVT,
- SExtPromoteOperand(Op.getOperand(0), PVT),
- Op.getOperand(1));
+ if (SDValue Op0 = SExtPromoteOperand(Op.getOperand(0), PVT))
+ return DAG.getNode(ISD::AssertSext, DL, PVT, Op0, Op.getOperand(1));
+ break;
case ISD::AssertZext:
- return DAG.getNode(ISD::AssertZext, dl, PVT,
- ZExtPromoteOperand(Op.getOperand(0), PVT),
- Op.getOperand(1));
+ if (SDValue Op0 = ZExtPromoteOperand(Op.getOperand(0), PVT))
+ return DAG.getNode(ISD::AssertZext, DL, PVT, Op0, Op.getOperand(1));
+ break;
case ISD::Constant: {
unsigned ExtOpc =
Op.getValueType().isByteSized() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
- return DAG.getNode(ExtOpc, dl, PVT, Op);
+ return DAG.getNode(ExtOpc, DL, PVT, Op);
}
}
if (!TLI.isOperationLegal(ISD::ANY_EXTEND, PVT))
return SDValue();
- return DAG.getNode(ISD::ANY_EXTEND, dl, PVT, Op);
+ return DAG.getNode(ISD::ANY_EXTEND, DL, PVT, Op);
}
SDValue DAGCombiner::SExtPromoteOperand(SDValue Op, EVT PVT) {
if (!TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG, PVT))
return SDValue();
EVT OldVT = Op.getValueType();
- SDLoc dl(Op);
+ SDLoc DL(Op);
bool Replace = false;
SDValue NewOp = PromoteOperand(Op, PVT, Replace);
if (!NewOp.getNode())
@@ -1010,13 +1133,13 @@ SDValue DAGCombiner::SExtPromoteOperand(SDValue Op, EVT PVT) {
if (Replace)
ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
- return DAG.getNode(ISD::SIGN_EXTEND_INREG, dl, NewOp.getValueType(), NewOp,
+ return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, NewOp.getValueType(), NewOp,
DAG.getValueType(OldVT));
}
SDValue DAGCombiner::ZExtPromoteOperand(SDValue Op, EVT PVT) {
EVT OldVT = Op.getValueType();
- SDLoc dl(Op);
+ SDLoc DL(Op);
bool Replace = false;
SDValue NewOp = PromoteOperand(Op, PVT, Replace);
if (!NewOp.getNode())
@@ -1025,7 +1148,7 @@ SDValue DAGCombiner::ZExtPromoteOperand(SDValue Op, EVT PVT) {
if (Replace)
ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
- return DAG.getZeroExtendInReg(NewOp, dl, OldVT);
+ return DAG.getZeroExtendInReg(NewOp, DL, OldVT);
}
/// Promote the specified integer binary operation if the target indicates it is
@@ -1051,37 +1174,44 @@ SDValue DAGCombiner::PromoteIntBinOp(SDValue Op) {
if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
assert(PVT != VT && "Don't know what type to promote to!");
+ LLVM_DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG));
+
bool Replace0 = false;
SDValue N0 = Op.getOperand(0);
SDValue NN0 = PromoteOperand(N0, PVT, Replace0);
- if (!NN0.getNode())
- return SDValue();
bool Replace1 = false;
SDValue N1 = Op.getOperand(1);
- SDValue NN1;
- if (N0 == N1)
- NN1 = NN0;
- else {
- NN1 = PromoteOperand(N1, PVT, Replace1);
- if (!NN1.getNode())
- return SDValue();
- }
+ SDValue NN1 = PromoteOperand(N1, PVT, Replace1);
+ SDLoc DL(Op);
- AddToWorklist(NN0.getNode());
- if (NN1.getNode())
- AddToWorklist(NN1.getNode());
+ SDValue RV =
+ DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, NN0, NN1));
+
+ // We are always replacing N0/N1's use in N and only need
+ // additional replacements if there are additional uses.
+ Replace0 &= !N0->hasOneUse();
+ Replace1 &= (N0 != N1) && !N1->hasOneUse();
- if (Replace0)
+ // Combine Op here so it is preserved past replacements.
+ CombineTo(Op.getNode(), RV);
+
+ // If operands have a use ordering, make sure we deal with
+ // predecessor first.
+ if (Replace0 && Replace1 && N0.getNode()->isPredecessorOf(N1.getNode())) {
+ std::swap(N0, N1);
+ std::swap(NN0, NN1);
+ }
+
+ if (Replace0) {
+ AddToWorklist(NN0.getNode());
ReplaceLoadWithPromotedLoad(N0.getNode(), NN0.getNode());
- if (Replace1)
+ }
+ if (Replace1) {
+ AddToWorklist(NN1.getNode());
ReplaceLoadWithPromotedLoad(N1.getNode(), NN1.getNode());
-
- DEBUG(dbgs() << "\nPromoting ";
- Op.getNode()->dump(&DAG));
- SDLoc dl(Op);
- return DAG.getNode(ISD::TRUNCATE, dl, VT,
- DAG.getNode(Opc, dl, PVT, NN0, NN1));
+ }
+ return Op;
}
return SDValue();
}
@@ -1109,26 +1239,32 @@ SDValue DAGCombiner::PromoteIntShiftOp(SDValue Op) {
if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
assert(PVT != VT && "Don't know what type to promote to!");
+ LLVM_DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG));
+
bool Replace = false;
SDValue N0 = Op.getOperand(0);
+ SDValue N1 = Op.getOperand(1);
if (Opc == ISD::SRA)
- N0 = SExtPromoteOperand(Op.getOperand(0), PVT);
+ N0 = SExtPromoteOperand(N0, PVT);
else if (Opc == ISD::SRL)
- N0 = ZExtPromoteOperand(Op.getOperand(0), PVT);
+ N0 = ZExtPromoteOperand(N0, PVT);
else
N0 = PromoteOperand(N0, PVT, Replace);
+
if (!N0.getNode())
return SDValue();
+ SDLoc DL(Op);
+ SDValue RV =
+ DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, N0, N1));
+
AddToWorklist(N0.getNode());
if (Replace)
ReplaceLoadWithPromotedLoad(Op.getOperand(0).getNode(), N0.getNode());
- DEBUG(dbgs() << "\nPromoting ";
- Op.getNode()->dump(&DAG));
- SDLoc dl(Op);
- return DAG.getNode(ISD::TRUNCATE, dl, VT,
- DAG.getNode(Opc, dl, PVT, N0, Op.getOperand(1)));
+ // Deal with Op being deleted.
+ if (Op && Op.getOpcode() != ISD::DELETED_NODE)
+ return RV;
}
return SDValue();
}
@@ -1155,8 +1291,7 @@ SDValue DAGCombiner::PromoteExtend(SDValue Op) {
// fold (aext (aext x)) -> (aext x)
// fold (aext (zext x)) -> (zext x)
// fold (aext (sext x)) -> (sext x)
- DEBUG(dbgs() << "\nPromoting ";
- Op.getNode()->dump(&DAG));
+ LLVM_DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG));
return DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, Op.getOperand(0));
}
return SDValue();
@@ -1166,6 +1301,9 @@ bool DAGCombiner::PromoteLoad(SDValue Op) {
if (!LegalOperations)
return false;
+ if (!ISD::isUNINDEXEDLoad(Op.getNode()))
+ return false;
+
EVT VT = Op.getValueType();
if (VT.isVector() || !VT.isInteger())
return false;
@@ -1182,24 +1320,19 @@ bool DAGCombiner::PromoteLoad(SDValue Op) {
if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
assert(PVT != VT && "Don't know what type to promote to!");
- SDLoc dl(Op);
+ SDLoc DL(Op);
SDNode *N = Op.getNode();
LoadSDNode *LD = cast<LoadSDNode>(N);
EVT MemVT = LD->getMemoryVT();
- ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD)
- ? (TLI.isLoadExtLegal(ISD::ZEXTLOAD, PVT, MemVT) ? ISD::ZEXTLOAD
- : ISD::EXTLOAD)
- : LD->getExtensionType();
- SDValue NewLD = DAG.getExtLoad(ExtType, dl, PVT,
+ ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD
+ : LD->getExtensionType();
+ SDValue NewLD = DAG.getExtLoad(ExtType, DL, PVT,
LD->getChain(), LD->getBasePtr(),
MemVT, LD->getMemOperand());
- SDValue Result = DAG.getNode(ISD::TRUNCATE, dl, VT, NewLD);
+ SDValue Result = DAG.getNode(ISD::TRUNCATE, DL, VT, NewLD);
- DEBUG(dbgs() << "\nPromoting ";
- N->dump(&DAG);
- dbgs() << "\nTo: ";
- Result.getNode()->dump(&DAG);
- dbgs() << '\n');
+ LLVM_DEBUG(dbgs() << "\nPromoting "; N->dump(&DAG); dbgs() << "\nTo: ";
+ Result.getNode()->dump(&DAG); dbgs() << '\n');
WorklistRemover DeadNodes(*this);
DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), NewLD.getValue(1));
@@ -1210,7 +1343,7 @@ bool DAGCombiner::PromoteLoad(SDValue Op) {
return false;
}
-/// \brief Recursively delete a node which has no uses and any operands for
+/// Recursively delete a node which has no uses and any operands for
/// which it is the only use.
///
/// Note that this both deletes the nodes and removes them from the worklist.
@@ -1259,8 +1392,7 @@ void DAGCombiner::Run(CombineLevel AtLevel) {
// changes of the root.
HandleSDNode Dummy(DAG.getRoot());
- // while the worklist isn't empty, find a node and
- // try and combine it.
+ // While the worklist isn't empty, find a node and try to combine it.
while (!WorklistMap.empty()) {
SDNode *N;
// The Worklist holds the SDNodes in order, but it may contain null entries.
@@ -1295,7 +1427,7 @@ void DAGCombiner::Run(CombineLevel AtLevel) {
continue;
}
- DEBUG(dbgs() << "\nCombining: "; N->dump(&DAG));
+ LLVM_DEBUG(dbgs() << "\nCombining: "; N->dump(&DAG));
// Add any operands of the new node which have not yet been combined to the
// worklist as well. Because the worklist uniques things already, this
@@ -1320,21 +1452,17 @@ void DAGCombiner::Run(CombineLevel AtLevel) {
continue;
assert(N->getOpcode() != ISD::DELETED_NODE &&
- RV.getNode()->getOpcode() != ISD::DELETED_NODE &&
+ RV.getOpcode() != ISD::DELETED_NODE &&
"Node was deleted but visit returned new node!");
- DEBUG(dbgs() << " ... into: ";
- RV.getNode()->dump(&DAG));
+ LLVM_DEBUG(dbgs() << " ... into: "; RV.getNode()->dump(&DAG));
- // Transfer debug value.
- DAG.TransferDbgValues(SDValue(N, 0), RV);
if (N->getNumValues() == RV.getNode()->getNumValues())
DAG.ReplaceAllUsesWith(N, RV.getNode());
else {
assert(N->getValueType(0) == RV.getValueType() &&
N->getNumValues() == 1 && "Type mismatch");
- SDValue OpV = RV;
- DAG.ReplaceAllUsesWith(N, &OpV);
+ DAG.ReplaceAllUsesWith(N, &RV);
}
// Push the new node and any users onto the worklist
@@ -1360,10 +1488,18 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::MERGE_VALUES: return visitMERGE_VALUES(N);
case ISD::ADD: return visitADD(N);
case ISD::SUB: return visitSUB(N);
+ case ISD::SADDSAT:
+ case ISD::UADDSAT: return visitADDSAT(N);
+ case ISD::SSUBSAT:
+ case ISD::USUBSAT: return visitSUBSAT(N);
case ISD::ADDC: return visitADDC(N);
+ case ISD::UADDO: return visitUADDO(N);
case ISD::SUBC: return visitSUBC(N);
+ case ISD::USUBO: return visitUSUBO(N);
case ISD::ADDE: return visitADDE(N);
+ case ISD::ADDCARRY: return visitADDCARRY(N);
case ISD::SUBE: return visitSUBE(N);
+ case ISD::SUBCARRY: return visitSUBCARRY(N);
case ISD::MUL: return visitMUL(N);
case ISD::SDIV: return visitSDIV(N);
case ISD::UDIV: return visitUDIV(N);
@@ -1387,7 +1523,11 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::SRL: return visitSRL(N);
case ISD::ROTR:
case ISD::ROTL: return visitRotate(N);
+ case ISD::FSHL:
+ case ISD::FSHR: return visitFunnelShift(N);
+ case ISD::ABS: return visitABS(N);
case ISD::BSWAP: return visitBSWAP(N);
+ case ISD::BITREVERSE: return visitBITREVERSE(N);
case ISD::CTLZ: return visitCTLZ(N);
case ISD::CTLZ_ZERO_UNDEF: return visitCTLZ_ZERO_UNDEF(N);
case ISD::CTTZ: return visitCTTZ(N);
@@ -1397,12 +1537,15 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::VSELECT: return visitVSELECT(N);
case ISD::SELECT_CC: return visitSELECT_CC(N);
case ISD::SETCC: return visitSETCC(N);
- case ISD::SETCCE: return visitSETCCE(N);
+ case ISD::SETCCCARRY: return visitSETCCCARRY(N);
case ISD::SIGN_EXTEND: return visitSIGN_EXTEND(N);
case ISD::ZERO_EXTEND: return visitZERO_EXTEND(N);
case ISD::ANY_EXTEND: return visitANY_EXTEND(N);
+ case ISD::AssertSext:
+ case ISD::AssertZext: return visitAssertExt(N);
case ISD::SIGN_EXTEND_INREG: return visitSIGN_EXTEND_INREG(N);
case ISD::SIGN_EXTEND_VECTOR_INREG: return visitSIGN_EXTEND_VECTOR_INREG(N);
+ case ISD::ZERO_EXTEND_VECTOR_INREG: return visitZERO_EXTEND_VECTOR_INREG(N);
case ISD::TRUNCATE: return visitTRUNCATE(N);
case ISD::BITCAST: return visitBITCAST(N);
case ISD::BUILD_PAIR: return visitBUILD_PAIR(N);
@@ -1414,6 +1557,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::FREM: return visitFREM(N);
case ISD::FSQRT: return visitFSQRT(N);
case ISD::FCOPYSIGN: return visitFCOPYSIGN(N);
+ case ISD::FPOW: return visitFPOW(N);
case ISD::SINT_TO_FP: return visitSINT_TO_FP(N);
case ISD::UINT_TO_FP: return visitUINT_TO_FP(N);
case ISD::FP_TO_SINT: return visitFP_TO_SINT(N);
@@ -1426,6 +1570,8 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::FFLOOR: return visitFFLOOR(N);
case ISD::FMINNUM: return visitFMINNUM(N);
case ISD::FMAXNUM: return visitFMAXNUM(N);
+ case ISD::FMINIMUM: return visitFMINIMUM(N);
+ case ISD::FMAXIMUM: return visitFMAXIMUM(N);
case ISD::FCEIL: return visitFCEIL(N);
case ISD::FTRUNC: return visitFTRUNC(N);
case ISD::BRCOND: return visitBRCOND(N);
@@ -1498,15 +1644,15 @@ SDValue DAGCombiner::combine(SDNode *N) {
}
}
- // If N is a commutative binary node, try commuting it to enable more
- // sdisel CSE.
- if (!RV.getNode() && SelectionDAG::isCommutativeBinOp(N->getOpcode()) &&
+ // If N is a commutative binary node, try eliminate it if the commuted
+ // version is already present in the DAG.
+ if (!RV.getNode() && TLI.isCommutativeBinOp(N->getOpcode()) &&
N->getNumValues() == 1) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
// Constant operands are canonicalized to RHS.
- if (isa<ConstantSDNode>(N0) || !isa<ConstantSDNode>(N1)) {
+ if (N0 != N1 && (isa<ConstantSDNode>(N0) || !isa<ConstantSDNode>(N1))) {
SDValue Ops[] = {N1, N0};
SDNode *CSENode = DAG.getNodeIfExists(N->getOpcode(), N->getVTList(), Ops,
N->getFlags());
@@ -1543,8 +1689,12 @@ SDValue DAGCombiner::visitTokenFactor(SDNode *N) {
return N->getOperand(1);
}
+ // Don't simplify token factors if optnone.
+ if (OptLevel == CodeGenOpt::None)
+ return SDValue();
+
SmallVector<SDNode *, 8> TFs; // List of token factors to visit.
- SmallVector<SDValue, 8> Ops; // Ops for replacing token factor.
+ SmallVector<SDValue, 8> Ops; // Ops for replacing token factor.
SmallPtrSet<SDNode*, 16> SeenOps;
bool Changed = false; // If we should replace this token factor.
@@ -1558,7 +1708,6 @@ SDValue DAGCombiner::visitTokenFactor(SDNode *N) {
// Check each of the operands.
for (const SDValue &Op : TF->op_values()) {
-
switch (Op.getOpcode()) {
case ISD::EntryToken:
// Entry tokens don't need to be added to the list. They are
@@ -1567,8 +1716,7 @@ SDValue DAGCombiner::visitTokenFactor(SDNode *N) {
break;
case ISD::TokenFactor:
- if (Op.hasOneUse() &&
- std::find(TFs.begin(), TFs.end(), Op.getNode()) == TFs.end()) {
+ if (Op.hasOneUse() && !is_contained(TFs, Op.getNode())) {
// Queue up for processing.
TFs.push_back(Op.getNode());
// Clean up in case the token factor is removed.
@@ -1576,7 +1724,7 @@ SDValue DAGCombiner::visitTokenFactor(SDNode *N) {
Changed = true;
break;
}
- // Fall thru
+ LLVM_FALLTHROUGH;
default:
// Only add if it isn't already in the list.
@@ -1589,26 +1737,108 @@ SDValue DAGCombiner::visitTokenFactor(SDNode *N) {
}
}
- SDValue Result;
+ // Remove Nodes that are chained to another node in the list. Do so
+ // by walking up chains breadth-first, stopping when we've seen
+ // another operand. In general we must climb to the EntryNode, but we can exit
+ // early if we find all remaining work is associated with just one operand as
+ // no further pruning is possible.
+
+ // List of nodes to search through and original Ops from which they originate.
+ SmallVector<std::pair<SDNode *, unsigned>, 8> Worklist;
+ SmallVector<unsigned, 8> OpWorkCount; // Count of work for each Op.
+ SmallPtrSet<SDNode *, 16> SeenChains;
+ bool DidPruneOps = false;
+
+ unsigned NumLeftToConsider = 0;
+ for (const SDValue &Op : Ops) {
+ Worklist.push_back(std::make_pair(Op.getNode(), NumLeftToConsider++));
+ OpWorkCount.push_back(1);
+ }
+
+ auto AddToWorklist = [&](unsigned CurIdx, SDNode *Op, unsigned OpNumber) {
+ // If this is an Op, we can remove the op from the list. Remark any
+ // search associated with it as from the current OpNumber.
+ if (SeenOps.count(Op) != 0) {
+ Changed = true;
+ DidPruneOps = true;
+ unsigned OrigOpNumber = 0;
+ while (OrigOpNumber < Ops.size() && Ops[OrigOpNumber].getNode() != Op)
+ OrigOpNumber++;
+ assert((OrigOpNumber != Ops.size()) &&
+ "expected to find TokenFactor Operand");
+ // Re-mark worklist from OrigOpNumber to OpNumber
+ for (unsigned i = CurIdx + 1; i < Worklist.size(); ++i) {
+ if (Worklist[i].second == OrigOpNumber) {
+ Worklist[i].second = OpNumber;
+ }
+ }
+ OpWorkCount[OpNumber] += OpWorkCount[OrigOpNumber];
+ OpWorkCount[OrigOpNumber] = 0;
+ NumLeftToConsider--;
+ }
+ // Add if it's a new chain
+ if (SeenChains.insert(Op).second) {
+ OpWorkCount[OpNumber]++;
+ Worklist.push_back(std::make_pair(Op, OpNumber));
+ }
+ };
+
+ for (unsigned i = 0; i < Worklist.size() && i < 1024; ++i) {
+ // We need to consider at least 2 Ops to prune.
+ if (NumLeftToConsider <= 1)
+ break;
+ auto CurNode = Worklist[i].first;
+ auto CurOpNumber = Worklist[i].second;
+ assert((OpWorkCount[CurOpNumber] > 0) &&
+ "Node should not appear in worklist");
+ switch (CurNode->getOpcode()) {
+ case ISD::EntryToken:
+ // Hitting EntryToken is the only way for the search to terminate
+ // without hitting another operand's search. Prevent us from marking
+ // this operand considered.
+ NumLeftToConsider++;
+ break;
+ case ISD::TokenFactor:
+ for (const SDValue &Op : CurNode->op_values())
+ AddToWorklist(i, Op.getNode(), CurOpNumber);
+ break;
+ case ISD::CopyFromReg:
+ case ISD::CopyToReg:
+ AddToWorklist(i, CurNode->getOperand(0).getNode(), CurOpNumber);
+ break;
+ default:
+ if (auto *MemNode = dyn_cast<MemSDNode>(CurNode))
+ AddToWorklist(i, MemNode->getChain().getNode(), CurOpNumber);
+ break;
+ }
+ OpWorkCount[CurOpNumber]--;
+ if (OpWorkCount[CurOpNumber] == 0)
+ NumLeftToConsider--;
+ }
// If we've changed things around then replace token factor.
if (Changed) {
+ SDValue Result;
if (Ops.empty()) {
// The entry token is the only possible outcome.
Result = DAG.getEntryNode();
} else {
- // New and improved token factor.
- Result = DAG.getNode(ISD::TokenFactor, SDLoc(N), MVT::Other, Ops);
+ if (DidPruneOps) {
+ SmallVector<SDValue, 8> PrunedOps;
+ // Drop ops that were encountered while walking another op's chain.
+ for (const SDValue &Op : Ops) {
+ if (SeenChains.count(Op.getNode()) == 0)
+ PrunedOps.push_back(Op);
+ }
+ Result = DAG.getNode(ISD::TokenFactor, SDLoc(N), MVT::Other, PrunedOps);
+ } else {
+ Result = DAG.getNode(ISD::TokenFactor, SDLoc(N), MVT::Other, Ops);
+ }
}
-
- // Add users to worklist if AA is enabled, since it may introduce
- // a lot of new chained token factors while removing memory deps.
- bool UseAA = CombinerAA.getNumOccurrences() > 0 ? CombinerAA
- : DAG.getSubtarget().useAA();
- return CombineTo(N, Result, UseAA /*add to worklist*/);
+ return Result;
}
-
- return Result;
+ return SDValue();
}
/// MERGE_VALUES can always be eliminated.
@@ -1621,24 +1851,179 @@ SDValue DAGCombiner::visitMERGE_VALUES(SDNode *N) {
// can be tried again once they have new operands.
AddUsersToWorklist(N);
do {
+ // Do as a single replacement to avoid rewalking use lists.
+ SmallVector<SDValue, 8> Ops;
for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
- DAG.ReplaceAllUsesOfValueWith(SDValue(N, i), N->getOperand(i));
+ Ops.push_back(N->getOperand(i));
+ DAG.ReplaceAllUsesWith(N, Ops.data());
} while (!N->use_empty());
deleteAndRecombine(N);
return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
-/// If \p N is a ContantSDNode with isOpaque() == false return it casted to a
-/// ContantSDNode pointer else nullptr.
+/// If \p N is a ConstantSDNode with isOpaque() == false return it casted to a
+/// ConstantSDNode pointer else nullptr.
static ConstantSDNode *getAsNonOpaqueConstant(SDValue N) {
ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N);
return Const != nullptr && !Const->isOpaque() ? Const : nullptr;
}
+SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
+ assert(ISD::isBinaryOp(BO) && "Unexpected binary operator");
+
+ // Don't do this unless the old select is going away. We want to eliminate the
+ // binary operator, not replace a binop with a select.
+ // TODO: Handle ISD::SELECT_CC.
+ unsigned SelOpNo = 0;
+ SDValue Sel = BO->getOperand(0);
+ if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) {
+ SelOpNo = 1;
+ Sel = BO->getOperand(1);
+ }
+
+ if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse())
+ return SDValue();
+
+ SDValue CT = Sel.getOperand(1);
+ if (!isConstantOrConstantVector(CT, true) &&
+ !isConstantFPBuildVectorOrConstantFP(CT))
+ return SDValue();
+
+ SDValue CF = Sel.getOperand(2);
+ if (!isConstantOrConstantVector(CF, true) &&
+ !isConstantFPBuildVectorOrConstantFP(CF))
+ return SDValue();
+
+ // Bail out if any constants are opaque because we can't constant fold those.
+ // The exception is "and" and "or" with either 0 or -1 in which case we can
+ // propagate non constant operands into select. I.e.:
+ // and (select Cond, 0, -1), X --> select Cond, 0, X
+ // or X, (select Cond, -1, 0) --> select Cond, -1, X
+ auto BinOpcode = BO->getOpcode();
+ bool CanFoldNonConst =
+ (BinOpcode == ISD::AND || BinOpcode == ISD::OR) &&
+ (isNullOrNullSplat(CT) || isAllOnesOrAllOnesSplat(CT)) &&
+ (isNullOrNullSplat(CF) || isAllOnesOrAllOnesSplat(CF));
+
+ SDValue CBO = BO->getOperand(SelOpNo ^ 1);
+ if (!CanFoldNonConst &&
+ !isConstantOrConstantVector(CBO, true) &&
+ !isConstantFPBuildVectorOrConstantFP(CBO))
+ return SDValue();
+
+ EVT VT = Sel.getValueType();
+
+ // In case of a shift, the value and the shift amount may have different VTs.
+ // For instance, on x86 the shift amount is i8 regardless of the LHS type.
+ // Bail out if we have swapped operands and the value types do not match.
+ // NB: x86 is fine if the operands are not swapped and the shift amount VT is
+ // no bigger than the shifted value's VT.
+ // TODO: it is possible to check for a shift operation, correct the VTs and
+ // still perform the optimization on x86 if needed.
+ if (SelOpNo && VT != CBO.getValueType())
+ return SDValue();
+
+ // We have a select-of-constants followed by a binary operator with a
+ // constant. Eliminate the binop by pulling the constant math into the select.
+ // Example: add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO
+ SDLoc DL(Sel);
+ SDValue NewCT = SelOpNo ? DAG.getNode(BinOpcode, DL, VT, CBO, CT)
+ : DAG.getNode(BinOpcode, DL, VT, CT, CBO);
+ if (!CanFoldNonConst && !NewCT.isUndef() &&
+ !isConstantOrConstantVector(NewCT, true) &&
+ !isConstantFPBuildVectorOrConstantFP(NewCT))
+ return SDValue();
+
+ SDValue NewCF = SelOpNo ? DAG.getNode(BinOpcode, DL, VT, CBO, CF)
+ : DAG.getNode(BinOpcode, DL, VT, CF, CBO);
+ if (!CanFoldNonConst && !NewCF.isUndef() &&
+ !isConstantOrConstantVector(NewCF, true) &&
+ !isConstantFPBuildVectorOrConstantFP(NewCF))
+ return SDValue();
+
+ return DAG.getSelect(DL, VT, Sel.getOperand(0), NewCT, NewCF);
+}
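
A minimal standalone C++ sketch of the scalar identity this fold relies on
(illustrative only; the helper names are invented, and 32-bit wrapping
arithmetic is assumed):

    #include <cassert>
    #include <cstdint>
    #include <initializer_list>

    // add (select Cond, CT, CF), CBO == select Cond, CT + CBO, CF + CBO
    static uint32_t selThenAdd(bool Cond, uint32_t CT, uint32_t CF, uint32_t CBO) {
      return (Cond ? CT : CF) + CBO;     // binop applied after the select
    }
    static uint32_t addThenSel(bool Cond, uint32_t CT, uint32_t CF, uint32_t CBO) {
      return Cond ? CT + CBO : CF + CBO; // constants folded into each arm
    }

    int main() {
      for (uint32_t CT : {0u, 1u, 7u, 0xFFFFFFFFu})
        for (uint32_t CF : {0u, 3u, 0x80000000u})
          for (uint32_t CBO : {0u, 5u, 123u})
            for (bool Cond : {false, true})
              assert(selThenAdd(Cond, CT, CF, CBO) == addThenSel(Cond, CT, CF, CBO));
      return 0;
    }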
+
+static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, SelectionDAG &DAG) {
+ assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
+ "Expecting add or sub");
+
+ // Match a constant operand and a zext operand for the math instruction:
+ // add Z, C
+ // sub C, Z
+ bool IsAdd = N->getOpcode() == ISD::ADD;
+ SDValue C = IsAdd ? N->getOperand(1) : N->getOperand(0);
+ SDValue Z = IsAdd ? N->getOperand(0) : N->getOperand(1);
+ auto *CN = dyn_cast<ConstantSDNode>(C);
+ if (!CN || Z.getOpcode() != ISD::ZERO_EXTEND)
+ return SDValue();
+
+ // Match the zext operand as a setcc of a boolean.
+ if (Z.getOperand(0).getOpcode() != ISD::SETCC ||
+ Z.getOperand(0).getValueType() != MVT::i1)
+ return SDValue();
+
+ // Match the compare as: setcc (X & 1), 0, eq.
+ SDValue SetCC = Z.getOperand(0);
+ ISD::CondCode CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get();
+ if (CC != ISD::SETEQ || !isNullConstant(SetCC.getOperand(1)) ||
+ SetCC.getOperand(0).getOpcode() != ISD::AND ||
+ !isOneConstant(SetCC.getOperand(0).getOperand(1)))
+ return SDValue();
+
+ // We are adding/subtracting a constant and an inverted low bit. Turn that
+ // into a subtract/add of the low bit with incremented/decremented constant:
+ // add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
+ // sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
+ EVT VT = C.getValueType();
+ SDLoc DL(N);
+ SDValue LowBit = DAG.getZExtOrTrunc(SetCC.getOperand(0), DL, VT);
+ SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT) :
+ DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
+ return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, C1, LowBit);
+}
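
The rewrites above can be spot-checked in standalone C++ (a sketch assuming
32-bit wrapping arithmetic; C is fixed to an arbitrary constant):

    #include <cassert>
    #include <cstdint>

    int main() {
      const uint32_t C = 100;
      for (uint32_t X = 0; X < 8; ++X) {
        uint32_t Inv = (X & 1) == 0 ? 1 : 0; // zext i1 (seteq (X & 1), 0)
        uint32_t Low = X & 1;                // zext (X & 1)
        // add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
        assert(Inv + C == (C + 1) - Low);
        // sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
        assert(C - Inv == (C - 1) + Low);
      }
      return 0;
    }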
+
+/// Try to fold an add/sub of a constant with a shifted 'not' of the sign
+/// bit into a shift and an add with a different constant.
+static SDValue foldAddSubOfSignBit(SDNode *N, SelectionDAG &DAG) {
+ assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
+ "Expecting add or sub");
+
+ // We need a constant operand for the add/sub, and the other operand is a
+ // logical shift right: add (srl), C or sub C, (srl).
+ bool IsAdd = N->getOpcode() == ISD::ADD;
+ SDValue ConstantOp = IsAdd ? N->getOperand(1) : N->getOperand(0);
+ SDValue ShiftOp = IsAdd ? N->getOperand(0) : N->getOperand(1);
+ ConstantSDNode *C = isConstOrConstSplat(ConstantOp);
+ if (!C || ShiftOp.getOpcode() != ISD::SRL)
+ return SDValue();
+
+ // The shift must be of a 'not' value.
+ SDValue Not = ShiftOp.getOperand(0);
+ if (!Not.hasOneUse() || !isBitwiseNot(Not))
+ return SDValue();
+
+ // The shift must be moving the sign bit to the least-significant-bit.
+ EVT VT = ShiftOp.getValueType();
+ SDValue ShAmt = ShiftOp.getOperand(1);
+ ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
+ if (!ShAmtC || ShAmtC->getZExtValue() != VT.getScalarSizeInBits() - 1)
+ return SDValue();
+
+ // Eliminate the 'not' by adjusting the shift and add/sub constant:
+ // add (srl (not X), 31), C --> add (sra X, 31), (C + 1)
+ // sub C, (srl (not X), 31) --> add (srl X, 31), (C - 1)
+ SDLoc DL(N);
+ auto ShOpcode = IsAdd ? ISD::SRA : ISD::SRL;
+ SDValue NewShift = DAG.getNode(ShOpcode, DL, VT, Not.getOperand(0), ShAmt);
+ APInt NewC = IsAdd ? C->getAPIntValue() + 1 : C->getAPIntValue() - 1;
+ return DAG.getNode(ISD::ADD, DL, VT, NewShift, DAG.getConstant(NewC, DL, VT));
+}
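
The same style of spot check for this fold (a sketch; it assumes >> on a
signed 32-bit value is an arithmetic shift, which holds on the usual targets):

    #include <cassert>
    #include <cstdint>
    #include <initializer_list>

    int main() {
      const uint32_t C = 1000;
      for (int32_t X : {5, 0, -1, -12345, INT32_MIN, INT32_MAX}) {
        uint32_t NotSign = (uint32_t)~X >> 31; // srl (not X), 31
        // add (srl (not X), 31), C --> add (sra X, 31), (C + 1)
        assert(NotSign + C == (uint32_t)(X >> 31) + (C + 1));
        // sub C, (srl (not X), 31) --> add (srl X, 31), (C - 1)
        assert(C - NotSign == ((uint32_t)X >> 31) + (C - 1));
      }
      return 0;
    }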
+
SDValue DAGCombiner::visitADD(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
EVT VT = N0.getValueType();
+ SDLoc DL(N);
// fold vector ops
if (VT.isVector()) {
@@ -1653,69 +2038,102 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
}
// fold (add x, undef) -> undef
- if (N0.getOpcode() == ISD::UNDEF)
+ if (N0.isUndef())
return N0;
- if (N1.getOpcode() == ISD::UNDEF)
+
+ if (N1.isUndef())
return N1;
- // fold (add c1, c2) -> c1+c2
- ConstantSDNode *N0C = getAsNonOpaqueConstant(N0);
- ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);
- if (N0C && N1C)
- return DAG.FoldConstantArithmetic(ISD::ADD, SDLoc(N), VT, N0C, N1C);
- // canonicalize constant to RHS
- if (isConstantIntBuildVectorOrConstantInt(N0) &&
- !isConstantIntBuildVectorOrConstantInt(N1))
- return DAG.getNode(ISD::ADD, SDLoc(N), VT, N1, N0);
+
+ if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) {
+ // canonicalize constant to RHS
+ if (!DAG.isConstantIntBuildVectorOrConstantInt(N1))
+ return DAG.getNode(ISD::ADD, DL, VT, N1, N0);
+ // fold (add c1, c2) -> c1+c2
+ return DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, N0.getNode(),
+ N1.getNode());
+ }
+
// fold (add x, 0) -> x
if (isNullConstant(N1))
return N0;
- // fold (add Sym, c) -> Sym+c
- if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(N0))
- if (!LegalOperations && TLI.isOffsetFoldingLegal(GA) && N1C &&
- GA->getOpcode() == ISD::GlobalAddress)
- return DAG.getGlobalAddress(GA->getGlobal(), SDLoc(N1C), VT,
- GA->getOffset() +
- (uint64_t)N1C->getSExtValue());
- // fold ((c1-A)+c2) -> (c1+c2)-A
- if (N1C && N0.getOpcode() == ISD::SUB)
- if (ConstantSDNode *N0C = getAsNonOpaqueConstant(N0.getOperand(0))) {
- SDLoc DL(N);
+
+ if (isConstantOrConstantVector(N1, /* NoOpaque */ true)) {
+ // fold ((c1-A)+c2) -> (c1+c2)-A
+ if (N0.getOpcode() == ISD::SUB &&
+ isConstantOrConstantVector(N0.getOperand(0), /* NoOpaque */ true)) {
+ // FIXME: Adding 2 constants should be handled by FoldConstantArithmetic.
return DAG.getNode(ISD::SUB, DL, VT,
- DAG.getConstant(N1C->getAPIntValue()+
- N0C->getAPIntValue(), DL, VT),
+ DAG.getNode(ISD::ADD, DL, VT, N1, N0.getOperand(0)),
N0.getOperand(1));
}
+
+ // add (sext i1 X), 1 -> zext (not i1 X)
+ // We don't transform this pattern:
+ // add (zext i1 X), -1 -> sext (not i1 X)
+ // because most (?) targets generate better code for the zext form.
+ if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
+ isOneOrOneSplat(N1)) {
+ SDValue X = N0.getOperand(0);
+ if ((!LegalOperations ||
+ (TLI.isOperationLegal(ISD::XOR, X.getValueType()) &&
+ TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) &&
+ X.getScalarValueSizeInBits() == 1) {
+ SDValue Not = DAG.getNOT(DL, X, X.getValueType());
+ return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Not);
+ }
+ }
+
+ // Undo the add -> or combine to merge constant offsets from a frame index.
+ if (N0.getOpcode() == ISD::OR &&
+ isa<FrameIndexSDNode>(N0.getOperand(0)) &&
+ isa<ConstantSDNode>(N0.getOperand(1)) &&
+ DAG.haveNoCommonBitsSet(N0.getOperand(0), N0.getOperand(1))) {
+ SDValue Add0 = DAG.getNode(ISD::ADD, DL, VT, N1, N0.getOperand(1));
+ return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Add0);
+ }
+ }
+
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
// reassociate add
- if (SDValue RADD = ReassociateOps(ISD::ADD, SDLoc(N), N0, N1))
+ if (SDValue RADD = ReassociateOps(ISD::ADD, DL, N0, N1, N->getFlags()))
return RADD;
+
// fold ((0-A) + B) -> B-A
- if (N0.getOpcode() == ISD::SUB && isNullConstant(N0.getOperand(0)))
- return DAG.getNode(ISD::SUB, SDLoc(N), VT, N1, N0.getOperand(1));
+ if (N0.getOpcode() == ISD::SUB && isNullOrNullSplat(N0.getOperand(0)))
+ return DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1));
+
// fold (A + (0-B)) -> A-B
- if (N1.getOpcode() == ISD::SUB && isNullConstant(N1.getOperand(0)))
- return DAG.getNode(ISD::SUB, SDLoc(N), VT, N0, N1.getOperand(1));
+ if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0)))
+ return DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(1));
+
// fold (A+(B-A)) -> B
if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(1))
return N1.getOperand(0);
+
// fold ((B-A)+A) -> B
if (N0.getOpcode() == ISD::SUB && N1 == N0.getOperand(1))
return N0.getOperand(0);
+
// fold (A+(B-(A+C))) to (B-C)
if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD &&
N0 == N1.getOperand(1).getOperand(0))
- return DAG.getNode(ISD::SUB, SDLoc(N), VT, N1.getOperand(0),
+ return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
N1.getOperand(1).getOperand(1));
+
// fold (A+(B-(C+A))) to (B-C)
if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD &&
N0 == N1.getOperand(1).getOperand(1))
- return DAG.getNode(ISD::SUB, SDLoc(N), VT, N1.getOperand(0),
+ return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
N1.getOperand(1).getOperand(0));
+
// fold (A+((B-A)+or-C)) to (B+or-C)
if ((N1.getOpcode() == ISD::SUB || N1.getOpcode() == ISD::ADD) &&
N1.getOperand(0).getOpcode() == ISD::SUB &&
N0 == N1.getOperand(0).getOperand(1))
- return DAG.getNode(N1.getOpcode(), SDLoc(N), VT,
- N1.getOperand(0).getOperand(0), N1.getOperand(1));
+ return DAG.getNode(N1.getOpcode(), DL, VT, N1.getOperand(0).getOperand(0),
+ N1.getOperand(1));
// fold (A-B)+(C-D) to (A+C)-(B+D) when A or C is constant
if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB) {
@@ -1724,52 +2142,148 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
SDValue N10 = N1.getOperand(0);
SDValue N11 = N1.getOperand(1);
- if (isa<ConstantSDNode>(N00) || isa<ConstantSDNode>(N10))
- return DAG.getNode(ISD::SUB, SDLoc(N), VT,
+ if (isConstantOrConstantVector(N00) || isConstantOrConstantVector(N10))
+ return DAG.getNode(ISD::SUB, DL, VT,
DAG.getNode(ISD::ADD, SDLoc(N0), VT, N00, N10),
DAG.getNode(ISD::ADD, SDLoc(N1), VT, N01, N11));
}
- if (!VT.isVector() && SimplifyDemandedBits(SDValue(N, 0)))
+ if (SDValue V = foldAddSubBoolOfMaskedVal(N, DAG))
+ return V;
+
+ if (SDValue V = foldAddSubOfSignBit(N, DAG))
+ return V;
+
+ if (SimplifyDemandedBits(SDValue(N, 0)))
return SDValue(N, 0);
// fold (a+b) -> (a|b) iff a and b share no bits.
if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
- VT.isInteger() && !VT.isVector() && DAG.haveNoCommonBitsSet(N0, N1))
- return DAG.getNode(ISD::OR, SDLoc(N), VT, N0, N1);
+ DAG.haveNoCommonBitsSet(N0, N1))
+ return DAG.getNode(ISD::OR, DL, VT, N0, N1);
+
+ // fold (add (xor a, -1), 1) -> (sub 0, a)
+ if (isBitwiseNot(N0) && isOneOrOneSplat(N1))
+ return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
+ N0.getOperand(0));
+
+ if (SDValue Combined = visitADDLike(N0, N1, N))
+ return Combined;
+
+ if (SDValue Combined = visitADDLike(N1, N0, N))
+ return Combined;
+
+ return SDValue();
+}
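
Two of the visitADD identities above, restated as a standalone C++ sketch
(32-bit wraparound assumed):

    #include <cassert>
    #include <cstdint>
    #include <initializer_list>

    int main() {
      // fold (add (xor a, -1), 1) -> (sub 0, a): ~a + 1 is two's-complement
      // negation.
      for (uint32_t A : {0u, 1u, 42u, 0xDEADBEEFu})
        assert((A ^ 0xFFFFFFFFu) + 1 == 0u - A);

      // fold (a+b) -> (a|b) iff a and b share no bits: with no common set
      // bits there is no carry, so add and or agree.
      uint32_t A = 0xF0F0u, B = 0x0F0Fu;
      assert((A & B) == 0 && A + B == (A | B));
      return 0;
    }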
+
+SDValue DAGCombiner::visitADDSAT(SDNode *N) {
+ unsigned Opcode = N->getOpcode();
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ EVT VT = N0.getValueType();
+ SDLoc DL(N);
+
+ // fold vector ops
+ if (VT.isVector()) {
+ // TODO SimplifyVBinOp
+
+ // fold (add_sat x, 0) -> x, vector edition
+ if (ISD::isBuildVectorAllZeros(N1.getNode()))
+ return N0;
+ if (ISD::isBuildVectorAllZeros(N0.getNode()))
+ return N1;
+ }
+
+ // fold (add_sat x, undef) -> -1
+ if (N0.isUndef() || N1.isUndef())
+ return DAG.getAllOnesConstant(DL, VT);
+
+ if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) {
+ // canonicalize constant to RHS
+ if (!DAG.isConstantIntBuildVectorOrConstantInt(N1))
+ return DAG.getNode(Opcode, DL, VT, N1, N0);
+ // fold (add_sat c1, c2) -> c3
+ return DAG.FoldConstantArithmetic(Opcode, DL, VT, N0.getNode(),
+ N1.getNode());
+ }
+
+ // fold (add_sat x, 0) -> x
+ if (isNullConstant(N1))
+ return N0;
+
+ // If it cannot overflow, transform into an add.
+ if (Opcode == ISD::UADDSAT)
+ if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
+ return DAG.getNode(ISD::ADD, DL, VT, N0, N1);
+
+ return SDValue();
+}
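
A scalar model of the UADDSAT fold above (a sketch; uaddsat here is a
hand-rolled stand-in for the ISD node, not an LLVM API):

    #include <cassert>
    #include <cstdint>

    // Saturating unsigned add: clamp to all-ones on wraparound.
    static uint32_t uaddsat(uint32_t A, uint32_t B) {
      uint32_t S = A + B;
      return S < A ? 0xFFFFFFFFu : S;
    }

    int main() {
      // When overflow is provably impossible (disjoint halves here), the
      // saturating add degenerates to a plain add.
      uint32_t A = 0x12340000u, B = 0x00005678u;
      assert(uaddsat(A, B) == A + B);
      // Otherwise it saturates.
      assert(uaddsat(0xFFFFFFFFu, 1) == 0xFFFFFFFFu);
      return 0;
    }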
+
+static SDValue getAsCarry(const TargetLowering &TLI, SDValue V) {
+ bool Masked = false;
+
+ // First, peel away TRUNCATE/ZERO_EXTEND/AND nodes due to legalization.
+ while (true) {
+ if (V.getOpcode() == ISD::TRUNCATE || V.getOpcode() == ISD::ZERO_EXTEND) {
+ V = V.getOperand(0);
+ continue;
+ }
+
+ if (V.getOpcode() == ISD::AND && isOneConstant(V.getOperand(1))) {
+ Masked = true;
+ V = V.getOperand(0);
+ continue;
+ }
+
+ break;
+ }
+
+ // If this is not a carry, return.
+ if (V.getResNo() != 1)
+ return SDValue();
+
+ if (V.getOpcode() != ISD::ADDCARRY && V.getOpcode() != ISD::SUBCARRY &&
+ V.getOpcode() != ISD::UADDO && V.getOpcode() != ISD::USUBO)
+ return SDValue();
+
+  // If the result is masked, then no matter what kind of bool it is we can
+  // return. If it isn't, then we need to make sure the bool can only take the
+  // values 0 and 1.
+ if (Masked ||
+ TLI.getBooleanContents(V.getValueType()) ==
+ TargetLoweringBase::ZeroOrOneBooleanContent)
+ return V;
+
+ return SDValue();
+}
+
+SDValue DAGCombiner::visitADDLike(SDValue N0, SDValue N1, SDNode *LocReference) {
+ EVT VT = N0.getValueType();
+ SDLoc DL(LocReference);
// fold (add x, shl(0 - y, n)) -> sub(x, shl(y, n))
if (N1.getOpcode() == ISD::SHL && N1.getOperand(0).getOpcode() == ISD::SUB &&
- isNullConstant(N1.getOperand(0).getOperand(0)))
- return DAG.getNode(ISD::SUB, SDLoc(N), VT, N0,
- DAG.getNode(ISD::SHL, SDLoc(N), VT,
+ isNullOrNullSplat(N1.getOperand(0).getOperand(0)))
+ return DAG.getNode(ISD::SUB, DL, VT, N0,
+ DAG.getNode(ISD::SHL, DL, VT,
N1.getOperand(0).getOperand(1),
N1.getOperand(1)));
- if (N0.getOpcode() == ISD::SHL && N0.getOperand(0).getOpcode() == ISD::SUB &&
- isNullConstant(N0.getOperand(0).getOperand(0)))
- return DAG.getNode(ISD::SUB, SDLoc(N), VT, N1,
- DAG.getNode(ISD::SHL, SDLoc(N), VT,
- N0.getOperand(0).getOperand(1),
- N0.getOperand(1)));
if (N1.getOpcode() == ISD::AND) {
SDValue AndOp0 = N1.getOperand(0);
unsigned NumSignBits = DAG.ComputeNumSignBits(AndOp0);
- unsigned DestBits = VT.getScalarType().getSizeInBits();
+ unsigned DestBits = VT.getScalarSizeInBits();
// (add z, (and (sbbl x, x), 1)) -> (sub z, (sbbl x, x))
// and similar xforms where the inner op is either ~0 or 0.
- if (NumSignBits == DestBits && isOneConstant(N1->getOperand(1))) {
- SDLoc DL(N);
- return DAG.getNode(ISD::SUB, DL, VT, N->getOperand(0), AndOp0);
- }
+ if (NumSignBits == DestBits && isOneOrOneSplat(N1->getOperand(1)))
+ return DAG.getNode(ISD::SUB, DL, VT, N0, AndOp0);
}
// add (sext i1), X -> sub X, (zext i1)
if (N0.getOpcode() == ISD::SIGN_EXTEND &&
N0.getOperand(0).getValueType() == MVT::i1 &&
!TLI.isOperationLegal(ISD::SIGN_EXTEND, MVT::i1)) {
- SDLoc DL(N);
SDValue ZExt = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
return DAG.getNode(ISD::SUB, DL, VT, N1, ZExt);
}
@@ -1778,13 +2292,25 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
if (TN->getVT() == MVT::i1) {
- SDLoc DL(N);
SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
DAG.getConstant(1, DL, VT));
return DAG.getNode(ISD::SUB, DL, VT, N0, ZExt);
}
}
+ // (add X, (addcarry Y, 0, Carry)) -> (addcarry X, Y, Carry)
+ if (N1.getOpcode() == ISD::ADDCARRY && isNullConstant(N1.getOperand(1)) &&
+ N1.getResNo() == 0)
+ return DAG.getNode(ISD::ADDCARRY, DL, N1->getVTList(),
+ N0, N1.getOperand(0), N1.getOperand(2));
+
+ // (add X, Carry) -> (addcarry X, 0, Carry)
+ if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT))
+ if (SDValue Carry = getAsCarry(TLI, N1))
+ return DAG.getNode(ISD::ADDCARRY, DL,
+ DAG.getVTList(VT, Carry.getValueType()), N0,
+ DAG.getConstant(0, DL, VT), Carry);
+
return SDValue();
}
@@ -1792,40 +2318,131 @@ SDValue DAGCombiner::visitADDC(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
EVT VT = N0.getValueType();
+ SDLoc DL(N);
// If the flag result is dead, turn this into an ADD.
if (!N->hasAnyUseOfValue(1))
- return CombineTo(N, DAG.getNode(ISD::ADD, SDLoc(N), VT, N0, N1),
- DAG.getNode(ISD::CARRY_FALSE,
- SDLoc(N), MVT::Glue));
+ return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
+ DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
// canonicalize constant to RHS.
ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
if (N0C && !N1C)
- return DAG.getNode(ISD::ADDC, SDLoc(N), N->getVTList(), N1, N0);
+ return DAG.getNode(ISD::ADDC, DL, N->getVTList(), N1, N0);
// fold (addc x, 0) -> x + no carry out
if (isNullConstant(N1))
return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE,
- SDLoc(N), MVT::Glue));
+ DL, MVT::Glue));
- // fold (addc a, b) -> (or a, b), CARRY_FALSE iff a and b share no bits.
- APInt LHSZero, LHSOne;
- APInt RHSZero, RHSOne;
- DAG.computeKnownBits(N0, LHSZero, LHSOne);
+ // If it cannot overflow, transform into an add.
+ if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
+ return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
+ DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
- if (LHSZero.getBoolValue()) {
- DAG.computeKnownBits(N1, RHSZero, RHSOne);
+ return SDValue();
+}
- // If all possibly-set bits on the LHS are clear on the RHS, return an OR.
- // If all possibly-set bits on the RHS are clear on the LHS, return an OR.
- if ((RHSZero & ~LHSZero) == ~LHSZero || (LHSZero & ~RHSZero) == ~RHSZero)
- return CombineTo(N, DAG.getNode(ISD::OR, SDLoc(N), VT, N0, N1),
- DAG.getNode(ISD::CARRY_FALSE,
- SDLoc(N), MVT::Glue));
+static SDValue flipBoolean(SDValue V, const SDLoc &DL, EVT VT,
+ SelectionDAG &DAG, const TargetLowering &TLI) {
+ SDValue Cst;
+ switch (TLI.getBooleanContents(VT)) {
+ case TargetLowering::ZeroOrOneBooleanContent:
+ case TargetLowering::UndefinedBooleanContent:
+ Cst = DAG.getConstant(1, DL, VT);
+ break;
+ case TargetLowering::ZeroOrNegativeOneBooleanContent:
+ Cst = DAG.getConstant(-1, DL, VT);
+ break;
}
+ return DAG.getNode(ISD::XOR, DL, VT, V, Cst);
+}
+
+static bool isBooleanFlip(SDValue V, EVT VT, const TargetLowering &TLI) {
+ if (V.getOpcode() != ISD::XOR) return false;
+ ConstantSDNode *Const = dyn_cast<ConstantSDNode>(V.getOperand(1));
+ if (!Const) return false;
+
+ switch(TLI.getBooleanContents(VT)) {
+ case TargetLowering::ZeroOrOneBooleanContent:
+ return Const->isOne();
+ case TargetLowering::ZeroOrNegativeOneBooleanContent:
+ return Const->isAllOnesValue();
+ case TargetLowering::UndefinedBooleanContent:
+ return (Const->getAPIntValue() & 0x01) == 1;
+ }
+ llvm_unreachable("Unsupported boolean content");
+}
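
In scalar terms, flipBoolean picks the XOR mask that inverts the target's
boolean representation, which is exactly what isBooleanFlip recognizes
(a sketch):

    #include <cassert>
    #include <cstdint>

    int main() {
      // ZeroOrOneBooleanContent: true = 1, false = 0; flip with XOR 1.
      uint32_t TrueZO = 1, FalseZO = 0;
      assert((TrueZO ^ 1u) == FalseZO && (FalseZO ^ 1u) == TrueZO);

      // ZeroOrNegativeOneBooleanContent: true = -1; flip with XOR -1.
      uint32_t TrueZN = 0xFFFFFFFFu, FalseZN = 0;
      assert((TrueZN ^ 0xFFFFFFFFu) == FalseZN && (FalseZN ^ 0xFFFFFFFFu) == TrueZN);
      return 0;
    }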
+
+SDValue DAGCombiner::visitUADDO(SDNode *N) {
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ EVT VT = N0.getValueType();
+ if (VT.isVector())
+ return SDValue();
+
+ EVT CarryVT = N->getValueType(1);
+ SDLoc DL(N);
+
+ // If the flag result is dead, turn this into an ADD.
+ if (!N->hasAnyUseOfValue(1))
+ return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
+ DAG.getUNDEF(CarryVT));
+
+ // canonicalize constant to RHS.
+ ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
+ ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
+ if (N0C && !N1C)
+ return DAG.getNode(ISD::UADDO, DL, N->getVTList(), N1, N0);
+
+ // fold (uaddo x, 0) -> x + no carry out
+ if (isNullConstant(N1))
+ return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));
+
+ // If it cannot overflow, transform into an add.
+ if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
+ return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
+ DAG.getConstant(0, DL, CarryVT));
+
+ // fold (uaddo (xor a, -1), 1) -> (usub 0, a) and flip carry.
+ if (isBitwiseNot(N0) && isOneOrOneSplat(N1)) {
+ SDValue Sub = DAG.getNode(ISD::USUBO, DL, N->getVTList(),
+ DAG.getConstant(0, DL, VT),
+ N0.getOperand(0));
+ return CombineTo(N, Sub,
+ flipBoolean(Sub.getValue(1), DL, CarryVT, DAG, TLI));
+ }
+
+ if (SDValue Combined = visitUADDOLike(N0, N1, N))
+ return Combined;
+
+ if (SDValue Combined = visitUADDOLike(N1, N0, N))
+ return Combined;
+
+ return SDValue();
+}
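
The carry-flipping fold in visitUADDO can be checked with a scalar sketch:
~a + 1 wraps exactly when a == 0, while 0 - a borrows exactly when a != 0,
so the two flag bits are complements of each other (32-bit model):

    #include <cassert>
    #include <cstdint>
    #include <initializer_list>

    int main() {
      for (uint32_t A : {0u, 1u, 77u, 0xFFFFFFFFu}) {
        uint32_t Sum = (A ^ 0xFFFFFFFFu) + 1u; // uaddo (xor a, -1), 1
        bool AddCarry = Sum < 1u;              // overflow bit of the uaddo
        uint32_t Diff = 0u - A;                // usubo 0, a
        bool SubBorrow = A != 0u;              // borrow bit of the usubo
        assert(Sum == Diff && AddCarry == !SubBorrow);
      }
      return 0;
    }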
+
+SDValue DAGCombiner::visitUADDOLike(SDValue N0, SDValue N1, SDNode *N) {
+ auto VT = N0.getValueType();
+
+ // (uaddo X, (addcarry Y, 0, Carry)) -> (addcarry X, Y, Carry)
+ // If Y + 1 cannot overflow.
+ if (N1.getOpcode() == ISD::ADDCARRY && isNullConstant(N1.getOperand(1))) {
+ SDValue Y = N1.getOperand(0);
+ SDValue One = DAG.getConstant(1, SDLoc(N), Y.getValueType());
+ if (DAG.computeOverflowKind(Y, One) == SelectionDAG::OFK_Never)
+ return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), N0, Y,
+ N1.getOperand(2));
+ }
+
+ // (uaddo X, Carry) -> (addcarry X, 0, Carry)
+ if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT))
+ if (SDValue Carry = getAsCarry(TLI, N1))
+ return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), N0,
+ DAG.getConstant(0, SDLoc(N), VT), Carry);
+
return SDValue();
}
@@ -1848,11 +2465,108 @@ SDValue DAGCombiner::visitADDE(SDNode *N) {
return SDValue();
}
+SDValue DAGCombiner::visitADDCARRY(SDNode *N) {
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ SDValue CarryIn = N->getOperand(2);
+ SDLoc DL(N);
+
+ // canonicalize constant to RHS
+ ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
+ ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
+ if (N0C && !N1C)
+ return DAG.getNode(ISD::ADDCARRY, DL, N->getVTList(), N1, N0, CarryIn);
+
+ // fold (addcarry x, y, false) -> (uaddo x, y)
+ if (isNullConstant(CarryIn)) {
+ if (!LegalOperations ||
+ TLI.isOperationLegalOrCustom(ISD::UADDO, N->getValueType(0)))
+ return DAG.getNode(ISD::UADDO, DL, N->getVTList(), N0, N1);
+ }
+
+ EVT CarryVT = CarryIn.getValueType();
+
+ // fold (addcarry 0, 0, X) -> (and (ext/trunc X), 1) and no carry.
+ if (isNullConstant(N0) && isNullConstant(N1)) {
+ EVT VT = N0.getValueType();
+ SDValue CarryExt = DAG.getBoolExtOrTrunc(CarryIn, DL, VT, CarryVT);
+ AddToWorklist(CarryExt.getNode());
+ return CombineTo(N, DAG.getNode(ISD::AND, DL, VT, CarryExt,
+ DAG.getConstant(1, DL, VT)),
+ DAG.getConstant(0, DL, CarryVT));
+ }
+
+ // fold (addcarry (xor a, -1), 0, !b) -> (subcarry 0, a, b) and flip carry.
+ if (isBitwiseNot(N0) && isNullConstant(N1) &&
+ isBooleanFlip(CarryIn, CarryVT, TLI)) {
+ SDValue Sub = DAG.getNode(ISD::SUBCARRY, DL, N->getVTList(),
+ DAG.getConstant(0, DL, N0.getValueType()),
+ N0.getOperand(0), CarryIn.getOperand(0));
+ return CombineTo(N, Sub,
+ flipBoolean(Sub.getValue(1), DL, CarryVT, DAG, TLI));
+ }
+
+ if (SDValue Combined = visitADDCARRYLike(N0, N1, CarryIn, N))
+ return Combined;
+
+ if (SDValue Combined = visitADDCARRYLike(N1, N0, CarryIn, N))
+ return Combined;
+
+ return SDValue();
+}
+
+SDValue DAGCombiner::visitADDCARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
+ SDNode *N) {
+ // Iff the flag result is dead:
+ // (addcarry (add|uaddo X, Y), 0, Carry) -> (addcarry X, Y, Carry)
+ if ((N0.getOpcode() == ISD::ADD ||
+ (N0.getOpcode() == ISD::UADDO && N0.getResNo() == 0)) &&
+ isNullConstant(N1) && !N->hasAnyUseOfValue(1))
+ return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(),
+ N0.getOperand(0), N0.getOperand(1), CarryIn);
+
+ /**
+   * When one of the addcarry arguments is itself a carry, we may be facing
+   * a diamond carry propagation. In that case we try to transform the DAG
+   * to ensure linear carry propagation if that is possible.
+ *
+ * We are trying to get:
+ * (addcarry X, 0, (addcarry A, B, Z):Carry)
+ */
+ if (auto Y = getAsCarry(TLI, N1)) {
+ /**
+ * (uaddo A, B)
+ * / \
+ * Carry Sum
+ * | \
+ * | (addcarry *, 0, Z)
+ * | /
+ * \ Carry
+ * | /
+ * (addcarry X, *, *)
+ */
+ if (Y.getOpcode() == ISD::UADDO &&
+ CarryIn.getResNo() == 1 &&
+ CarryIn.getOpcode() == ISD::ADDCARRY &&
+ isNullConstant(CarryIn.getOperand(1)) &&
+ CarryIn.getOperand(0) == Y.getValue(0)) {
+ auto NewY = DAG.getNode(ISD::ADDCARRY, SDLoc(N), Y->getVTList(),
+ Y.getOperand(0), Y.getOperand(1),
+ CarryIn.getOperand(2));
+ AddToWorklist(NewY.getNode());
+ return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), N0,
+ DAG.getConstant(0, SDLoc(N), N0.getValueType()),
+ NewY.getValue(1));
+ }
+ }
+
+ return SDValue();
+}
+
// Since it may not be valid to emit a fold to zero for vector initializers
// check if we can before folding.
-static SDValue tryFoldToZero(SDLoc DL, const TargetLowering &TLI, EVT VT,
- SelectionDAG &DAG,
- bool LegalOperations, bool LegalTypes) {
+static SDValue tryFoldToZero(const SDLoc &DL, const TargetLowering &TLI, EVT VT,
+ SelectionDAG &DAG, bool LegalOperations) {
if (!VT.isVector())
return DAG.getConstant(0, DL, VT);
if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT))
@@ -1864,6 +2578,7 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
EVT VT = N0.getValueType();
+ SDLoc DL(N);
// fold vector ops
if (VT.isVector()) {
@@ -1878,66 +2593,155 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
// fold (sub x, x) -> 0
// FIXME: Refactor this and xor and other similar operations together.
if (N0 == N1)
- return tryFoldToZero(SDLoc(N), TLI, VT, DAG, LegalOperations, LegalTypes);
- // fold (sub c1, c2) -> c1-c2
- ConstantSDNode *N0C = getAsNonOpaqueConstant(N0);
+ return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
+ if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
+ DAG.isConstantIntBuildVectorOrConstantInt(N1)) {
+ // fold (sub c1, c2) -> c1-c2
+ return DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, N0.getNode(),
+ N1.getNode());
+ }
+
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);
- if (N0C && N1C)
- return DAG.FoldConstantArithmetic(ISD::SUB, SDLoc(N), VT, N0C, N1C);
+
// fold (sub x, c) -> (add x, -c)
if (N1C) {
- SDLoc DL(N);
return DAG.getNode(ISD::ADD, DL, VT, N0,
DAG.getConstant(-N1C->getAPIntValue(), DL, VT));
}
+
+ if (isNullOrNullSplat(N0)) {
+ unsigned BitWidth = VT.getScalarSizeInBits();
+ // Right-shifting everything out but the sign bit followed by negation is
+ // the same as flipping arithmetic/logical shift type without the negation:
+ // -(X >>u 31) -> (X >>s 31)
+ // -(X >>s 31) -> (X >>u 31)
+ if (N1->getOpcode() == ISD::SRA || N1->getOpcode() == ISD::SRL) {
+ ConstantSDNode *ShiftAmt = isConstOrConstSplat(N1.getOperand(1));
+ if (ShiftAmt && ShiftAmt->getZExtValue() == BitWidth - 1) {
+ auto NewSh = N1->getOpcode() == ISD::SRA ? ISD::SRL : ISD::SRA;
+ if (!LegalOperations || TLI.isOperationLegal(NewSh, VT))
+ return DAG.getNode(NewSh, DL, VT, N1.getOperand(0), N1.getOperand(1));
+ }
+ }
+
+ // 0 - X --> 0 if the sub is NUW.
+ if (N->getFlags().hasNoUnsignedWrap())
+ return N0;
+
+ if (DAG.MaskedValueIsZero(N1, ~APInt::getSignMask(BitWidth))) {
+ // N1 is either 0 or the minimum signed value. If the sub is NSW, then
+ // N1 must be 0 because negating the minimum signed value is undefined.
+ if (N->getFlags().hasNoSignedWrap())
+ return N0;
+
+ // 0 - X --> X if X is 0 or the minimum signed value.
+ return N1;
+ }
+ }
+
// Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1)
- if (isAllOnesConstant(N0))
- return DAG.getNode(ISD::XOR, SDLoc(N), VT, N1, N0);
+ if (isAllOnesOrAllOnesSplat(N0))
+ return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
+
+ // fold (A - (0-B)) -> A+B
+ if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0)))
+ return DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(1));
+
// fold A-(A-B) -> B
if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(0))
return N1.getOperand(1);
+
// fold (A+B)-A -> B
if (N0.getOpcode() == ISD::ADD && N0.getOperand(0) == N1)
return N0.getOperand(1);
+
// fold (A+B)-B -> A
if (N0.getOpcode() == ISD::ADD && N0.getOperand(1) == N1)
return N0.getOperand(0);
+
// fold C2-(A+C1) -> (C2-C1)-A
- ConstantSDNode *N1C1 = N1.getOpcode() != ISD::ADD ? nullptr :
- dyn_cast<ConstantSDNode>(N1.getOperand(1).getNode());
- if (N1.getOpcode() == ISD::ADD && N0C && N1C1) {
- SDLoc DL(N);
- SDValue NewC = DAG.getConstant(N0C->getAPIntValue() - N1C1->getAPIntValue(),
- DL, VT);
- return DAG.getNode(ISD::SUB, DL, VT, NewC,
- N1.getOperand(0));
+ if (N1.getOpcode() == ISD::ADD) {
+ SDValue N11 = N1.getOperand(1);
+ if (isConstantOrConstantVector(N0, /* NoOpaques */ true) &&
+ isConstantOrConstantVector(N11, /* NoOpaques */ true)) {
+ SDValue NewC = DAG.getNode(ISD::SUB, DL, VT, N0, N11);
+ return DAG.getNode(ISD::SUB, DL, VT, NewC, N1.getOperand(0));
+ }
}
+
// fold ((A+(B+or-C))-B) -> A+or-C
if (N0.getOpcode() == ISD::ADD &&
(N0.getOperand(1).getOpcode() == ISD::SUB ||
N0.getOperand(1).getOpcode() == ISD::ADD) &&
N0.getOperand(1).getOperand(0) == N1)
- return DAG.getNode(N0.getOperand(1).getOpcode(), SDLoc(N), VT,
- N0.getOperand(0), N0.getOperand(1).getOperand(1));
+ return DAG.getNode(N0.getOperand(1).getOpcode(), DL, VT, N0.getOperand(0),
+ N0.getOperand(1).getOperand(1));
+
// fold ((A+(C+B))-B) -> A+C
- if (N0.getOpcode() == ISD::ADD &&
- N0.getOperand(1).getOpcode() == ISD::ADD &&
+ if (N0.getOpcode() == ISD::ADD && N0.getOperand(1).getOpcode() == ISD::ADD &&
N0.getOperand(1).getOperand(1) == N1)
- return DAG.getNode(ISD::ADD, SDLoc(N), VT,
- N0.getOperand(0), N0.getOperand(1).getOperand(0));
+ return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0),
+ N0.getOperand(1).getOperand(0));
+
// fold ((A-(B-C))-C) -> A-B
- if (N0.getOpcode() == ISD::SUB &&
- N0.getOperand(1).getOpcode() == ISD::SUB &&
+ if (N0.getOpcode() == ISD::SUB && N0.getOperand(1).getOpcode() == ISD::SUB &&
N0.getOperand(1).getOperand(1) == N1)
- return DAG.getNode(ISD::SUB, SDLoc(N), VT,
- N0.getOperand(0), N0.getOperand(1).getOperand(0));
+ return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0),
+ N0.getOperand(1).getOperand(0));
+
+ // fold (A-(B-C)) -> A+(C-B)
+ if (N1.getOpcode() == ISD::SUB && N1.hasOneUse())
+ return DAG.getNode(ISD::ADD, DL, VT, N0,
+ DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(1),
+ N1.getOperand(0)));
+
+ // fold (X - (-Y * Z)) -> (X + (Y * Z))
+ if (N1.getOpcode() == ISD::MUL && N1.hasOneUse()) {
+ if (N1.getOperand(0).getOpcode() == ISD::SUB &&
+ isNullOrNullSplat(N1.getOperand(0).getOperand(0))) {
+ SDValue Mul = DAG.getNode(ISD::MUL, DL, VT,
+ N1.getOperand(0).getOperand(1),
+ N1.getOperand(1));
+ return DAG.getNode(ISD::ADD, DL, VT, N0, Mul);
+ }
+ if (N1.getOperand(1).getOpcode() == ISD::SUB &&
+ isNullOrNullSplat(N1.getOperand(1).getOperand(0))) {
+ SDValue Mul = DAG.getNode(ISD::MUL, DL, VT,
+ N1.getOperand(0),
+ N1.getOperand(1).getOperand(1));
+ return DAG.getNode(ISD::ADD, DL, VT, N0, Mul);
+ }
+ }
// If either operand of a sub is undef, the result is undef
- if (N0.getOpcode() == ISD::UNDEF)
+ if (N0.isUndef())
return N0;
- if (N1.getOpcode() == ISD::UNDEF)
+ if (N1.isUndef())
return N1;
+ if (SDValue V = foldAddSubBoolOfMaskedVal(N, DAG))
+ return V;
+
+ if (SDValue V = foldAddSubOfSignBit(N, DAG))
+ return V;
+
+ // fold Y = sra (X, size(X)-1); sub (xor (X, Y), Y) -> (abs X)
+ if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) {
+ if (N0.getOpcode() == ISD::XOR && N1.getOpcode() == ISD::SRA) {
+ SDValue X0 = N0.getOperand(0), X1 = N0.getOperand(1);
+ SDValue S0 = N1.getOperand(0);
+ if ((X0 == S0 && X1 == N1) || (X0 == N1 && X1 == S0)) {
+ unsigned OpSizeInBits = VT.getScalarSizeInBits();
+ if (ConstantSDNode *C = isConstOrConstSplat(N1.getOperand(1)))
+ if (C->getAPIntValue() == (OpSizeInBits - 1))
+ return DAG.getNode(ISD::ABS, SDLoc(N), VT, S0);
+ }
+ }
+ }
+
// If the relocation model supports it, consider symbol offsets.
if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(N0))
if (!LegalOperations && TLI.isOffsetFoldingLegal(GA)) {
@@ -1945,25 +2749,72 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
if (N1C && GA->getOpcode() == ISD::GlobalAddress)
return DAG.getGlobalAddress(GA->getGlobal(), SDLoc(N1C), VT,
GA->getOffset() -
- (uint64_t)N1C->getSExtValue());
+ (uint64_t)N1C->getSExtValue());
// fold (sub Sym+c1, Sym+c2) -> c1-c2
if (GlobalAddressSDNode *GB = dyn_cast<GlobalAddressSDNode>(N1))
if (GA->getGlobal() == GB->getGlobal())
return DAG.getConstant((uint64_t)GA->getOffset() - GB->getOffset(),
- SDLoc(N), VT);
+ DL, VT);
}
// sub X, (sextinreg Y i1) -> add X, (and Y 1)
if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
if (TN->getVT() == MVT::i1) {
- SDLoc DL(N);
SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
DAG.getConstant(1, DL, VT));
return DAG.getNode(ISD::ADD, DL, VT, N0, ZExt);
}
}
+ // Prefer an add for more folding potential and possibly better codegen:
+ // sub N0, (lshr N10, width-1) --> add N0, (ashr N10, width-1)
+ if (!LegalOperations && N1.getOpcode() == ISD::SRL && N1.hasOneUse()) {
+ SDValue ShAmt = N1.getOperand(1);
+ ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
+ if (ShAmtC && ShAmtC->getZExtValue() == N1.getScalarValueSizeInBits() - 1) {
+ SDValue SRA = DAG.getNode(ISD::SRA, DL, VT, N1.getOperand(0), ShAmt);
+ return DAG.getNode(ISD::ADD, DL, VT, N0, SRA);
+ }
+ }
+
+ return SDValue();
+}
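
Two of the visitSUB folds above, as a standalone C++ sketch (assumes
arithmetic >> for signed values; INT32_MIN is skipped because std::abs is
undefined there):

    #include <cassert>
    #include <cstdint>
    #include <cstdlib>
    #include <initializer_list>

    int main() {
      for (int32_t X : {7, 0, -7, -123456, INT32_MAX}) {
        // -(X >>u 31) -> (X >>s 31): negating the carried-out sign bit is
        // the same as broadcasting it.
        assert(0u - ((uint32_t)X >> 31) == (uint32_t)(X >> 31));

        // Y = sra(X, 31); sub(xor(X, Y), Y) -> abs(X)
        int32_t Y = X >> 31;
        assert((X ^ Y) - Y == std::abs(X));
      }
      return 0;
    }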
+
+SDValue DAGCombiner::visitSUBSAT(SDNode *N) {
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ EVT VT = N0.getValueType();
+ SDLoc DL(N);
+
+ // fold vector ops
+ if (VT.isVector()) {
+ // TODO SimplifyVBinOp
+
+ // fold (sub_sat x, 0) -> x, vector edition
+ if (ISD::isBuildVectorAllZeros(N1.getNode()))
+ return N0;
+ }
+
+ // fold (sub_sat x, undef) -> 0
+ if (N0.isUndef() || N1.isUndef())
+ return DAG.getConstant(0, DL, VT);
+
+ // fold (sub_sat x, x) -> 0
+ if (N0 == N1)
+ return DAG.getConstant(0, DL, VT);
+
+ if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
+ DAG.isConstantIntBuildVectorOrConstantInt(N1)) {
+ // fold (sub_sat c1, c2) -> c3
+ return DAG.FoldConstantArithmetic(N->getOpcode(), DL, VT, N0.getNode(),
+ N1.getNode());
+ }
+
+ // fold (sub_sat x, 0) -> x
+ if (isNullConstant(N1))
+ return N0;
+
return SDValue();
}
@@ -1995,6 +2846,38 @@ SDValue DAGCombiner::visitSUBC(SDNode *N) {
return SDValue();
}
+SDValue DAGCombiner::visitUSUBO(SDNode *N) {
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ EVT VT = N0.getValueType();
+ if (VT.isVector())
+ return SDValue();
+
+ EVT CarryVT = N->getValueType(1);
+ SDLoc DL(N);
+
+  // If the flag result is dead, turn this into a SUB.
+ if (!N->hasAnyUseOfValue(1))
+ return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
+ DAG.getUNDEF(CarryVT));
+
+ // fold (usubo x, x) -> 0 + no borrow
+ if (N0 == N1)
+ return CombineTo(N, DAG.getConstant(0, DL, VT),
+ DAG.getConstant(0, DL, CarryVT));
+
+ // fold (usubo x, 0) -> x + no borrow
+ if (isNullConstant(N1))
+ return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));
+
+ // Canonicalize (usubo -1, x) -> ~x, i.e. (xor x, -1) + no borrow
+ if (isAllOnesConstant(N0))
+ return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
+ DAG.getConstant(0, DL, CarryVT));
+
+ return SDValue();
+}
+
SDValue DAGCombiner::visitSUBE(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
@@ -2007,13 +2890,28 @@ SDValue DAGCombiner::visitSUBE(SDNode *N) {
return SDValue();
}
+SDValue DAGCombiner::visitSUBCARRY(SDNode *N) {
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ SDValue CarryIn = N->getOperand(2);
+
+ // fold (subcarry x, y, false) -> (usubo x, y)
+ if (isNullConstant(CarryIn)) {
+ if (!LegalOperations ||
+ TLI.isOperationLegalOrCustom(ISD::USUBO, N->getValueType(0)))
+ return DAG.getNode(ISD::USUBO, SDLoc(N), N->getVTList(), N0, N1);
+ }
+
+ return SDValue();
+}
+
SDValue DAGCombiner::visitMUL(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
EVT VT = N0.getValueType();
// fold (mul x, undef) -> 0
- if (N0.getOpcode() == ISD::UNDEF || N1.getOpcode() == ISD::UNDEF)
+ if (N0.isUndef() || N1.isUndef())
return DAG.getConstant(0, SDLoc(N), VT);
bool N0IsConst = false;
@@ -2026,8 +2924,14 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
if (SDValue FoldedVOp = SimplifyVBinOp(N))
return FoldedVOp;
- N0IsConst = isConstantSplatVector(N0.getNode(), ConstValue0);
- N1IsConst = isConstantSplatVector(N1.getNode(), ConstValue1);
+ N0IsConst = ISD::isConstantSplatVector(N0.getNode(), ConstValue0);
+ N1IsConst = ISD::isConstantSplatVector(N1.getNode(), ConstValue1);
+ assert((!N0IsConst ||
+ ConstValue0.getBitWidth() == VT.getScalarSizeInBits()) &&
+ "Splat APInt should be element width");
+ assert((!N1IsConst ||
+ ConstValue1.getBitWidth() == VT.getScalarSizeInBits()) &&
+ "Splat APInt should be element width");
} else {
N0IsConst = isa<ConstantSDNode>(N0);
if (N0IsConst) {
@@ -2047,19 +2951,19 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
N0.getNode(), N1.getNode());
// canonicalize constant to RHS (vector doesn't have to splat)
- if (isConstantIntBuildVectorOrConstantInt(N0) &&
- !isConstantIntBuildVectorOrConstantInt(N1))
+ if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
+ !DAG.isConstantIntBuildVectorOrConstantInt(N1))
return DAG.getNode(ISD::MUL, SDLoc(N), VT, N1, N0);
// fold (mul x, 0) -> 0
- if (N1IsConst && ConstValue1 == 0)
+ if (N1IsConst && ConstValue1.isNullValue())
return N1;
- // We require a splat of the entire scalar bit width for non-contiguous
- // bit patterns.
- bool IsFullSplat =
- ConstValue1.getBitWidth() == VT.getScalarType().getSizeInBits();
// fold (mul x, 1) -> x
- if (N1IsConst && ConstValue1 == 1 && IsFullSplat)
+ if (N1IsConst && ConstValue1.isOneValue())
return N0;
+
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
// fold (mul x, -1) -> 0-x
if (N1IsConst && ConstValue1.isAllOnesValue()) {
SDLoc DL(N);
@@ -2067,16 +2971,17 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
DAG.getConstant(0, DL, VT), N0);
}
// fold (mul x, (1 << c)) -> x << c
- if (N1IsConst && !N1IsOpaqueConst && ConstValue1.isPowerOf2() &&
- IsFullSplat) {
+ if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
+ DAG.isKnownToBeAPowerOfTwo(N1) &&
+ (!VT.isVector() || Level <= AfterLegalizeVectorOps)) {
SDLoc DL(N);
- return DAG.getNode(ISD::SHL, DL, VT, N0,
- DAG.getConstant(ConstValue1.logBase2(), DL,
- getShiftAmountTy(N0.getValueType())));
+ SDValue LogBase2 = BuildLogBase2(N1, DL);
+ EVT ShiftVT = getShiftAmountTy(N0.getValueType());
+ SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
+ return DAG.getNode(ISD::SHL, DL, VT, N0, Trunc);
}
// fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
- if (N1IsConst && !N1IsOpaqueConst && (-ConstValue1).isPowerOf2() &&
- IsFullSplat) {
+ if (N1IsConst && !N1IsOpaqueConst && (-ConstValue1).isPowerOf2()) {
unsigned Log2Val = (-ConstValue1).logBase2();
SDLoc DL(N);
// FIXME: If the input is something that is easily negated (e.g. a
@@ -2088,46 +2993,74 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
getShiftAmountTy(N0.getValueType()))));
}
- APInt Val;
+ // Try to transform multiply-by-(power-of-2 +/- 1) into shift and add/sub.
+ // mul x, (2^N + 1) --> add (shl x, N), x
+ // mul x, (2^N - 1) --> sub (shl x, N), x
+ // Examples: x * 33 --> (x << 5) + x
+ // x * 15 --> (x << 4) - x
+ // x * -33 --> -((x << 5) + x)
+ // x * -15 --> -((x << 4) - x) ; this reduces --> x - (x << 4)
+ if (N1IsConst && TLI.decomposeMulByConstant(VT, N1)) {
+ // TODO: We could handle more general decomposition of any constant by
+ // having the target set a limit on number of ops and making a
+ // callback to determine that sequence (similar to sqrt expansion).
+ unsigned MathOp = ISD::DELETED_NODE;
+ APInt MulC = ConstValue1.abs();
+ if ((MulC - 1).isPowerOf2())
+ MathOp = ISD::ADD;
+ else if ((MulC + 1).isPowerOf2())
+ MathOp = ISD::SUB;
+
+ if (MathOp != ISD::DELETED_NODE) {
+ unsigned ShAmt = MathOp == ISD::ADD ? (MulC - 1).logBase2()
+ : (MulC + 1).logBase2();
+ assert(ShAmt > 0 && ShAmt < VT.getScalarSizeInBits() &&
+ "Not expecting multiply-by-constant that could have simplified");
+ SDLoc DL(N);
+ SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, N0,
+ DAG.getConstant(ShAmt, DL, VT));
+ SDValue R = DAG.getNode(MathOp, DL, VT, Shl, N0);
+ if (ConstValue1.isNegative())
+ R = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), R);
+ return R;
+ }
+ }
+
// (mul (shl X, c1), c2) -> (mul X, c2 << c1)
- if (N1IsConst && N0.getOpcode() == ISD::SHL &&
- (isConstantSplatVector(N0.getOperand(1).getNode(), Val) ||
- isa<ConstantSDNode>(N0.getOperand(1)))) {
- SDValue C3 = DAG.getNode(ISD::SHL, SDLoc(N), VT,
- N1, N0.getOperand(1));
- AddToWorklist(C3.getNode());
- return DAG.getNode(ISD::MUL, SDLoc(N), VT,
- N0.getOperand(0), C3);
+ if (N0.getOpcode() == ISD::SHL &&
+ isConstantOrConstantVector(N1, /* NoOpaques */ true) &&
+ isConstantOrConstantVector(N0.getOperand(1), /* NoOpaques */ true)) {
+ SDValue C3 = DAG.getNode(ISD::SHL, SDLoc(N), VT, N1, N0.getOperand(1));
+ if (isConstantOrConstantVector(C3))
+ return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), C3);
}
// Change (mul (shl X, C), Y) -> (shl (mul X, Y), C) when the shift has one
// use.
{
- SDValue Sh(nullptr,0), Y(nullptr,0);
+ SDValue Sh(nullptr, 0), Y(nullptr, 0);
+
// Check for both (mul (shl X, C), Y) and (mul Y, (shl X, C)).
if (N0.getOpcode() == ISD::SHL &&
- (isConstantSplatVector(N0.getOperand(1).getNode(), Val) ||
- isa<ConstantSDNode>(N0.getOperand(1))) &&
+ isConstantOrConstantVector(N0.getOperand(1)) &&
N0.getNode()->hasOneUse()) {
Sh = N0; Y = N1;
} else if (N1.getOpcode() == ISD::SHL &&
- isa<ConstantSDNode>(N1.getOperand(1)) &&
+ isConstantOrConstantVector(N1.getOperand(1)) &&
N1.getNode()->hasOneUse()) {
Sh = N1; Y = N0;
}
if (Sh.getNode()) {
- SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N), VT,
- Sh.getOperand(0), Y);
- return DAG.getNode(ISD::SHL, SDLoc(N), VT,
- Mul, Sh.getOperand(1));
+ SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N), VT, Sh.getOperand(0), Y);
+ return DAG.getNode(ISD::SHL, SDLoc(N), VT, Mul, Sh.getOperand(1));
}
}
// fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2)
- if (isConstantIntBuildVectorOrConstantInt(N1) &&
+ if (DAG.isConstantIntBuildVectorOrConstantInt(N1) &&
N0.getOpcode() == ISD::ADD &&
- isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)) &&
+ DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)) &&
isMulAddWithConstProfitable(N, N0, N1))
return DAG.getNode(ISD::ADD, SDLoc(N), VT,
DAG.getNode(ISD::MUL, SDLoc(N0), VT,
@@ -2136,7 +3069,7 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
N0.getOperand(1), N1));
// reassociate mul
- if (SDValue RMUL = ReassociateOps(ISD::MUL, SDLoc(N), N0, N1))
+ if (SDValue RMUL = ReassociateOps(ISD::MUL, SDLoc(N), N0, N1, N->getFlags()))
return RMUL;
return SDValue();
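
The decomposeMulByConstant shapes listed above are easy to sanity-check in
standalone C++ (a sketch; 32-bit wraparound assumed):

    #include <cassert>
    #include <cstdint>
    #include <initializer_list>

    int main() {
      for (uint32_t X : {0u, 1u, 3u, 1000u, 0xABCDu}) {
        // mul x, (2^N + 1) --> add (shl x, N), x
        assert(X * 33u == (X << 5) + X);
        // mul x, (2^N - 1) --> sub (shl x, N), x
        assert(X * 15u == (X << 4) - X);
        // Negative constants negate the result: x * -33 --> -((x << 5) + x)
        assert(X * (uint32_t)-33 == 0u - ((X << 5) + X));
      }
      return 0;
    }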
@@ -2146,7 +3079,10 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
static bool isDivRemLibcallAvailable(SDNode *Node, bool isSigned,
const TargetLowering &TLI) {
RTLIB::Libcall LC;
- switch (Node->getSimpleValueType(0).SimpleTy) {
+ EVT NodeType = Node->getValueType(0);
+ if (!NodeType.isSimple())
+ return false;
+ switch (NodeType.getSimpleVT().SimpleTy) {
default: return false; // No libcall for vector types.
case MVT::i8: LC= isSigned ? RTLIB::SDIVREM_I8 : RTLIB::UDIVREM_I8; break;
case MVT::i16: LC= isSigned ? RTLIB::SDIVREM_I16 : RTLIB::UDIVREM_I16; break;
@@ -2163,14 +3099,18 @@ SDValue DAGCombiner::useDivRem(SDNode *Node) {
if (Node->use_empty())
return SDValue(); // This is a dead node, leave it alone.
+ unsigned Opcode = Node->getOpcode();
+ bool isSigned = (Opcode == ISD::SDIV) || (Opcode == ISD::SREM);
+ unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;
+
+ // DivMod lib calls can still work on non-legal types if using lib-calls.
EVT VT = Node->getValueType(0);
- if (!TLI.isTypeLegal(VT))
+ if (VT.isVector() || !VT.isInteger())
return SDValue();
- unsigned Opcode = Node->getOpcode();
- bool isSigned = (Opcode == ISD::SDIV) || (Opcode == ISD::SREM);
+ if (!TLI.isTypeLegal(VT) && !TLI.isOperationCustom(DivRemOpc, VT))
+ return SDValue();
- unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;
// If DIVREM is going to get expanded into a libcall,
// but there is no libcall available, then don't combine.
if (!TLI.isOperationLegalOrCustom(DivRemOpc, VT) &&
@@ -2195,7 +3135,8 @@ SDValue DAGCombiner::useDivRem(SDNode *Node) {
for (SDNode::use_iterator UI = Op0.getNode()->use_begin(),
UE = Op0.getNode()->use_end(); UI != UE; ++UI) {
SDNode *User = *UI;
- if (User == Node || User->use_empty())
+ if (User == Node || User->getOpcode() == ISD::DELETED_NODE ||
+ User->use_empty())
continue;
// Convert the other matching node(s), too;
// otherwise, the DIVREM may get target-legalized into something
@@ -2224,10 +3165,57 @@ SDValue DAGCombiner::useDivRem(SDNode *Node) {
return combined;
}
+static SDValue simplifyDivRem(SDNode *N, SelectionDAG &DAG) {
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ EVT VT = N->getValueType(0);
+ SDLoc DL(N);
+
+ unsigned Opc = N->getOpcode();
+ bool IsDiv = (ISD::SDIV == Opc) || (ISD::UDIV == Opc);
+ ConstantSDNode *N1C = isConstOrConstSplat(N1);
+
+ // X / undef -> undef
+ // X % undef -> undef
+ // X / 0 -> undef
+ // X % 0 -> undef
+ // NOTE: This includes vectors where any divisor element is zero/undef.
+ if (DAG.isUndef(Opc, {N0, N1}))
+ return DAG.getUNDEF(VT);
+
+ // undef / X -> 0
+ // undef % X -> 0
+ if (N0.isUndef())
+ return DAG.getConstant(0, DL, VT);
+
+ // 0 / X -> 0
+ // 0 % X -> 0
+ ConstantSDNode *N0C = isConstOrConstSplat(N0);
+ if (N0C && N0C->isNullValue())
+ return N0;
+
+ // X / X -> 1
+ // X % X -> 0
+ if (N0 == N1)
+ return DAG.getConstant(IsDiv ? 1 : 0, DL, VT);
+
+ // X / 1 -> X
+ // X % 1 -> 0
+ // If this is a boolean op (single-bit element type), we can't have
+ // division-by-zero or remainder-by-zero, so assume the divisor is 1.
+ // TODO: Similarly, if we're zero-extending a boolean divisor, then assume
+ // it's a 1.
+ if ((N1C && N1C->isOne()) || (VT.getScalarType() == MVT::i1))
+ return IsDiv ? N0 : DAG.getConstant(0, DL, VT);
+
+ return SDValue();
+}
+
SDValue DAGCombiner::visitSDIV(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
EVT VT = N->getValueType(0);
+ EVT CCVT = getSetCCResultType(VT);
// fold vector ops
if (VT.isVector())
@@ -2241,85 +3229,129 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) {
ConstantSDNode *N1C = isConstOrConstSplat(N1);
if (N0C && N1C && !N0C->isOpaque() && !N1C->isOpaque())
return DAG.FoldConstantArithmetic(ISD::SDIV, DL, VT, N0C, N1C);
- // fold (sdiv X, 1) -> X
- if (N1C && N1C->isOne())
- return N0;
// fold (sdiv X, -1) -> 0-X
if (N1C && N1C->isAllOnesValue())
- return DAG.getNode(ISD::SUB, DL, VT,
- DAG.getConstant(0, DL, VT), N0);
+ return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), N0);
+ // fold (sdiv X, MIN_SIGNED) -> select(X == MIN_SIGNED, 1, 0)
+ if (N1C && N1C->getAPIntValue().isMinSignedValue())
+ return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
+ DAG.getConstant(1, DL, VT),
+ DAG.getConstant(0, DL, VT));
+
+ if (SDValue V = simplifyDivRem(N, DAG))
+ return V;
+
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
// If we know the sign bits of both operands are zero, strength reduce to a
// udiv instead. Handles (X&15) /s 4 -> X&15 >> 2
- if (!VT.isVector()) {
- if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
- return DAG.getNode(ISD::UDIV, DL, N1.getValueType(), N0, N1);
+ if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
+ return DAG.getNode(ISD::UDIV, DL, N1.getValueType(), N0, N1);
+
+ if (SDValue V = visitSDIVLike(N0, N1, N)) {
+ // If the corresponding remainder node exists, update its users with
+    // (Dividend - (Quotient * Divisor)).
+ if (SDNode *RemNode = DAG.getNodeIfExists(ISD::SREM, N->getVTList(),
+ { N0, N1 })) {
+ SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
+ SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
+ AddToWorklist(Mul.getNode());
+ AddToWorklist(Sub.getNode());
+ CombineTo(RemNode, Sub);
+ }
+ return V;
}
+ // sdiv, srem -> sdivrem
+ // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
+ // true. Otherwise, we break the simplification logic in visitREM().
+ AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
+ if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
+ if (SDValue DivRem = useDivRem(N))
+ return DivRem;
+
+ return SDValue();
+}
+
+SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) {
+ SDLoc DL(N);
+ EVT VT = N->getValueType(0);
+ EVT CCVT = getSetCCResultType(VT);
+ unsigned BitWidth = VT.getScalarSizeInBits();
+
+  // Helper for determining whether a value is a power-of-2 constant scalar or
+  // a vector of such elements.
+ auto IsPowerOfTwo = [](ConstantSDNode *C) {
+ if (C->isNullValue() || C->isOpaque())
+ return false;
+ if (C->getAPIntValue().isPowerOf2())
+ return true;
+ if ((-C->getAPIntValue()).isPowerOf2())
+ return true;
+ return false;
+ };
+
// fold (sdiv X, pow2) -> simple ops after legalize
// FIXME: We check for the exact bit here because the generic lowering gives
// better results in that case. The target-specific lowering should learn how
// to handle exact sdivs efficiently.
- if (N1C && !N1C->isNullValue() && !N1C->isOpaque() &&
- !cast<BinaryWithFlagsSDNode>(N)->Flags.hasExact() &&
- (N1C->getAPIntValue().isPowerOf2() ||
- (-N1C->getAPIntValue()).isPowerOf2())) {
+ if (!N->getFlags().hasExact() && ISD::matchUnaryPredicate(N1, IsPowerOfTwo)) {
// Target-specific implementation of sdiv x, pow2.
if (SDValue Res = BuildSDIVPow2(N))
return Res;
- unsigned lg2 = N1C->getAPIntValue().countTrailingZeros();
+ // Create constants that are functions of the shift amount value.
+ EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
+ SDValue Bits = DAG.getConstant(BitWidth, DL, ShiftAmtTy);
+ SDValue C1 = DAG.getNode(ISD::CTTZ, DL, VT, N1);
+ C1 = DAG.getZExtOrTrunc(C1, DL, ShiftAmtTy);
+ SDValue Inexact = DAG.getNode(ISD::SUB, DL, ShiftAmtTy, Bits, C1);
+ if (!isConstantOrConstantVector(Inexact))
+ return SDValue();
// Splat the sign bit into the register
- SDValue SGN =
- DAG.getNode(ISD::SRA, DL, VT, N0,
- DAG.getConstant(VT.getScalarSizeInBits() - 1, DL,
- getShiftAmountTy(N0.getValueType())));
- AddToWorklist(SGN.getNode());
+ SDValue Sign = DAG.getNode(ISD::SRA, DL, VT, N0,
+ DAG.getConstant(BitWidth - 1, DL, ShiftAmtTy));
+ AddToWorklist(Sign.getNode());
// Add (N0 < 0) ? abs2 - 1 : 0;
- SDValue SRL =
- DAG.getNode(ISD::SRL, DL, VT, SGN,
- DAG.getConstant(VT.getScalarSizeInBits() - lg2, DL,
- getShiftAmountTy(SGN.getValueType())));
- SDValue ADD = DAG.getNode(ISD::ADD, DL, VT, N0, SRL);
- AddToWorklist(SRL.getNode());
- AddToWorklist(ADD.getNode()); // Divide by pow2
- SDValue SRA = DAG.getNode(ISD::SRA, DL, VT, ADD,
- DAG.getConstant(lg2, DL,
- getShiftAmountTy(ADD.getValueType())));
-
- // If we're dividing by a positive value, we're done. Otherwise, we must
- // negate the result.
- if (N1C->getAPIntValue().isNonNegative())
- return SRA;
-
- AddToWorklist(SRA.getNode());
- return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), SRA);
+ SDValue Srl = DAG.getNode(ISD::SRL, DL, VT, Sign, Inexact);
+ AddToWorklist(Srl.getNode());
+ SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, Srl);
+ AddToWorklist(Add.getNode());
+ SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Add, C1);
+ AddToWorklist(Sra.getNode());
+
+    // Special case: (sdiv X, 1) -> X
+    // Special case: (sdiv X, -1) -> 0-X
+ SDValue One = DAG.getConstant(1, DL, VT);
+ SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
+ SDValue IsOne = DAG.getSetCC(DL, CCVT, N1, One, ISD::SETEQ);
+ SDValue IsAllOnes = DAG.getSetCC(DL, CCVT, N1, AllOnes, ISD::SETEQ);
+ SDValue IsOneOrAllOnes = DAG.getNode(ISD::OR, DL, CCVT, IsOne, IsAllOnes);
+ Sra = DAG.getSelect(DL, VT, IsOneOrAllOnes, N0, Sra);
+
+ // If dividing by a positive value, we're done. Otherwise, the result must
+ // be negated.
+ SDValue Zero = DAG.getConstant(0, DL, VT);
+ SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, Zero, Sra);
+
+ // FIXME: Use SELECT_CC once we improve SELECT_CC constant-folding.
+ SDValue IsNeg = DAG.getSetCC(DL, CCVT, N1, Zero, ISD::SETLT);
+ SDValue Res = DAG.getSelect(DL, VT, IsNeg, Sub, Sra);
+ return Res;
}
// If integer divide is expensive and we satisfy the requirements, emit an
// alternate sequence. Targets may check function attributes for size/speed
// trade-offs.
- AttributeSet Attr = DAG.getMachineFunction().getFunction()->getAttributes();
- if (N1C && !TLI.isIntDivCheap(N->getValueType(0), Attr))
+ AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
+ if (isConstantOrConstantVector(N1) &&
+ !TLI.isIntDivCheap(N->getValueType(0), Attr))
if (SDValue Op = BuildSDIV(N))
return Op;
- // sdiv, srem -> sdivrem
- // If the divisor is constant, then return DIVREM only if isIntDivCheap() is true.
- // Otherwise, we break the simplification logic in visitREM().
- if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
- if (SDValue DivRem = useDivRem(N))
- return DivRem;
-
- // undef / X -> 0
- if (N0.getOpcode() == ISD::UNDEF)
- return DAG.getConstant(0, DL, VT);
- // X / undef -> undef
- if (N1.getOpcode() == ISD::UNDEF)
- return N1;
-
return SDValue();
}
@@ -2327,6 +3359,7 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
EVT VT = N->getValueType(0);
+ EVT CCVT = getSetCCResultType(VT);
// fold vector ops
if (VT.isVector())
@@ -2342,48 +3375,83 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) {
if (SDValue Folded = DAG.FoldConstantArithmetic(ISD::UDIV, DL, VT,
N0C, N1C))
return Folded;
+ // fold (udiv X, -1) -> select(X == -1, 1, 0)
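+  // (Unsigned -1 is the maximum value, so the quotient is 1 exactly when
+  // X == -1 and 0 for every smaller X.)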
+ if (N1C && N1C->getAPIntValue().isAllOnesValue())
+ return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
+ DAG.getConstant(1, DL, VT),
+ DAG.getConstant(0, DL, VT));
+
+ if (SDValue V = simplifyDivRem(N, DAG))
+ return V;
+
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
+ if (SDValue V = visitUDIVLike(N0, N1, N)) {
+ // If the corresponding remainder node exists, update its users with
+    // (Dividend - (Quotient * Divisor)).
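+    // This lets a (udiv X, Y) / (urem X, Y) pair share a single division.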
+ if (SDNode *RemNode = DAG.getNodeIfExists(ISD::UREM, N->getVTList(),
+ { N0, N1 })) {
+ SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
+ SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
+ AddToWorklist(Mul.getNode());
+ AddToWorklist(Sub.getNode());
+ CombineTo(RemNode, Sub);
+ }
+ return V;
+ }
+
+  // udiv, urem -> udivrem
+ // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
+ // true. Otherwise, we break the simplification logic in visitREM().
+ AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
+ if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
+ if (SDValue DivRem = useDivRem(N))
+ return DivRem;
+
+ return SDValue();
+}
+
+SDValue DAGCombiner::visitUDIVLike(SDValue N0, SDValue N1, SDNode *N) {
+ SDLoc DL(N);
+ EVT VT = N->getValueType(0);
+
// fold (udiv x, (1 << c)) -> x >>u c
- if (N1C && !N1C->isOpaque() && N1C->getAPIntValue().isPowerOf2())
- return DAG.getNode(ISD::SRL, DL, VT, N0,
- DAG.getConstant(N1C->getAPIntValue().logBase2(), DL,
- getShiftAmountTy(N0.getValueType())));
+ if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
+ DAG.isKnownToBeAPowerOfTwo(N1)) {
+ SDValue LogBase2 = BuildLogBase2(N1, DL);
+ AddToWorklist(LogBase2.getNode());
+
+ EVT ShiftVT = getShiftAmountTy(N0.getValueType());
+ SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
+ AddToWorklist(Trunc.getNode());
+ return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
+ }
// fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2
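+  // E.g. udiv x, (shl 4, y) --> x >>u (y + 2), because
+  // (4 << y) == (1 << (y + 2)).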
if (N1.getOpcode() == ISD::SHL) {
- if (ConstantSDNode *SHC = getAsNonOpaqueConstant(N1.getOperand(0))) {
- if (SHC->getAPIntValue().isPowerOf2()) {
- EVT ADDVT = N1.getOperand(1).getValueType();
- SDValue Add = DAG.getNode(ISD::ADD, DL, ADDVT,
- N1.getOperand(1),
- DAG.getConstant(SHC->getAPIntValue()
- .logBase2(),
- DL, ADDVT));
- AddToWorklist(Add.getNode());
- return DAG.getNode(ISD::SRL, DL, VT, N0, Add);
- }
+ SDValue N10 = N1.getOperand(0);
+ if (isConstantOrConstantVector(N10, /*NoOpaques*/ true) &&
+ DAG.isKnownToBeAPowerOfTwo(N10)) {
+ SDValue LogBase2 = BuildLogBase2(N10, DL);
+ AddToWorklist(LogBase2.getNode());
+
+ EVT ADDVT = N1.getOperand(1).getValueType();
+ SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
+ AddToWorklist(Trunc.getNode());
+ SDValue Add = DAG.getNode(ISD::ADD, DL, ADDVT, N1.getOperand(1), Trunc);
+ AddToWorklist(Add.getNode());
+ return DAG.getNode(ISD::SRL, DL, VT, N0, Add);
}
}
// fold (udiv x, c) -> alternate
- AttributeSet Attr = DAG.getMachineFunction().getFunction()->getAttributes();
- if (N1C && !TLI.isIntDivCheap(N->getValueType(0), Attr))
+ AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
+ if (isConstantOrConstantVector(N1) &&
+ !TLI.isIntDivCheap(N->getValueType(0), Attr))
if (SDValue Op = BuildUDIV(N))
return Op;
- // sdiv, srem -> sdivrem
- // If the divisor is constant, then return DIVREM only if isIntDivCheap() is true.
- // Otherwise, we break the simplification logic in visitREM().
- if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
- if (SDValue DivRem = useDivRem(N))
- return DivRem;
-
- // undef / X -> 0
- if (N0.getOpcode() == ISD::UNDEF)
- return DAG.getConstant(0, DL, VT);
- // X / undef -> undef
- if (N1.getOpcode() == ISD::UNDEF)
- return N1;
-
return SDValue();
}
@@ -2393,6 +3461,8 @@ SDValue DAGCombiner::visitREM(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
EVT VT = N->getValueType(0);
+ EVT CCVT = getSetCCResultType(VT);
+
bool isSigned = (Opcode == ISD::SREM);
SDLoc DL(N);
@@ -2402,56 +3472,60 @@ SDValue DAGCombiner::visitREM(SDNode *N) {
if (N0C && N1C)
if (SDValue Folded = DAG.FoldConstantArithmetic(Opcode, DL, VT, N0C, N1C))
return Folded;
+  // fold (urem X, -1) -> select(X == -1, 0, X)
+ if (!isSigned && N1C && N1C->getAPIntValue().isAllOnesValue())
+ return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
+ DAG.getConstant(0, DL, VT), N0);
+
+ if (SDValue V = simplifyDivRem(N, DAG))
+ return V;
+
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
if (isSigned) {
// If we know the sign bits of both operands are zero, strength reduce to a
// urem instead. Handles (X & 0x0FFFFFFF) %s 16 -> X&15
- if (!VT.isVector()) {
- if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
- return DAG.getNode(ISD::UREM, DL, VT, N0, N1);
- }
+ if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
+ return DAG.getNode(ISD::UREM, DL, VT, N0, N1);
} else {
- // fold (urem x, pow2) -> (and x, pow2-1)
- if (N1C && !N1C->isNullValue() && !N1C->isOpaque() &&
- N1C->getAPIntValue().isPowerOf2()) {
- return DAG.getNode(ISD::AND, DL, VT, N0,
- DAG.getConstant(N1C->getAPIntValue() - 1, DL, VT));
- }
- // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1))
- if (N1.getOpcode() == ISD::SHL) {
- if (ConstantSDNode *SHC = getAsNonOpaqueConstant(N1.getOperand(0))) {
- if (SHC->getAPIntValue().isPowerOf2()) {
- SDValue Add =
- DAG.getNode(ISD::ADD, DL, VT, N1,
- DAG.getConstant(APInt::getAllOnesValue(VT.getSizeInBits()), DL,
- VT));
- AddToWorklist(Add.getNode());
- return DAG.getNode(ISD::AND, DL, VT, N0, Add);
- }
- }
+ SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
+ if (DAG.isKnownToBeAPowerOfTwo(N1)) {
+ // fold (urem x, pow2) -> (and x, pow2-1)
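+      // E.g. x % 8 == x & 7; adding -1 to the power of two builds the mask.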
+ SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
+ AddToWorklist(Add.getNode());
+ return DAG.getNode(ISD::AND, DL, VT, N0, Add);
+ }
+ if (N1.getOpcode() == ISD::SHL &&
+ DAG.isKnownToBeAPowerOfTwo(N1.getOperand(0))) {
+ // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1))
+ SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
+ AddToWorklist(Add.getNode());
+ return DAG.getNode(ISD::AND, DL, VT, N0, Add);
}
}
- AttributeSet Attr = DAG.getMachineFunction().getFunction()->getAttributes();
+ AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
// If X/C can be simplified by the division-by-constant logic, lower
// X%C to the equivalent of X-X/C*C.
- // To avoid mangling nodes, this simplification requires that the combine()
- // call for the speculative DIV must not cause a DIVREM conversion. We guard
- // against this by skipping the simplification if isIntDivCheap(). When
- // div is not cheap, combine will not return a DIVREM. Regardless,
- // checking cheapness here makes sense since the simplification results in
- // fatter code.
- if (N1C && !N1C->isNullValue() && !TLI.isIntDivCheap(VT, Attr)) {
- unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
- SDValue Div = DAG.getNode(DivOpcode, DL, VT, N0, N1);
- AddToWorklist(Div.getNode());
- SDValue OptimizedDiv = combine(Div.getNode());
- if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != Div.getNode()) {
- assert((OptimizedDiv.getOpcode() != ISD::UDIVREM) &&
- (OptimizedDiv.getOpcode() != ISD::SDIVREM));
+ // Reuse the SDIVLike/UDIVLike combines - to avoid mangling nodes, the
+ // speculative DIV must not cause a DIVREM conversion. We guard against this
+ // by skipping the simplification if isIntDivCheap(). When div is not cheap,
+ // combine will not return a DIVREM. Regardless, checking cheapness here
+ // makes sense since the simplification results in fatter code.
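+  // E.g. X srem 7 can become X - (X sdiv 7) * 7, where the sdiv is
+  // typically lowered to a multiply by a magic constant.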
+ if (DAG.isKnownNeverZero(N1) && !TLI.isIntDivCheap(VT, Attr)) {
+ SDValue OptimizedDiv =
+ isSigned ? visitSDIVLike(N0, N1, N) : visitUDIVLike(N0, N1, N);
+ if (OptimizedDiv.getNode()) {
+ // If the equivalent Div node also exists, update its users.
+ unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
+ if (SDNode *DivNode = DAG.getNodeIfExists(DivOpcode, N->getVTList(),
+ { N0, N1 }))
+ CombineTo(DivNode, OptimizedDiv);
SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, OptimizedDiv, N1);
SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
+ AddToWorklist(OptimizedDiv.getNode());
AddToWorklist(Mul.getNode());
return Sub;
}
@@ -2461,13 +3535,6 @@ SDValue DAGCombiner::visitREM(SDNode *N) {
if (SDValue DivRem = useDivRem(N))
return DivRem.getValue(1);
- // undef % X -> 0
- if (N0.getOpcode() == ISD::UNDEF)
- return DAG.getConstant(0, DL, VT);
- // X % undef -> undef
- if (N1.getOpcode() == ISD::UNDEF)
- return N1;
-
return SDValue();
}
@@ -2477,20 +3544,26 @@ SDValue DAGCombiner::visitMULHS(SDNode *N) {
EVT VT = N->getValueType(0);
SDLoc DL(N);
+ if (VT.isVector()) {
+ // fold (mulhs x, 0) -> 0
+ if (ISD::isBuildVectorAllZeros(N1.getNode()))
+ return N1;
+ if (ISD::isBuildVectorAllZeros(N0.getNode()))
+ return N0;
+ }
+
// fold (mulhs x, 0) -> 0
if (isNullConstant(N1))
return N1;
// fold (mulhs x, 1) -> (sra x, size(x)-1)
- if (isOneConstant(N1)) {
- SDLoc DL(N);
+ if (isOneConstant(N1))
return DAG.getNode(ISD::SRA, DL, N0.getValueType(), N0,
- DAG.getConstant(N0.getValueType().getSizeInBits() - 1,
- DL,
+ DAG.getConstant(N0.getValueSizeInBits() - 1, DL,
getShiftAmountTy(N0.getValueType())));
- }
+
// fold (mulhs x, undef) -> 0
- if (N0.getOpcode() == ISD::UNDEF || N1.getOpcode() == ISD::UNDEF)
- return DAG.getConstant(0, SDLoc(N), VT);
+ if (N0.isUndef() || N1.isUndef())
+ return DAG.getConstant(0, DL, VT);
// If the type twice as wide is legal, transform the mulhs to a wider multiply
// plus a shift.
@@ -2518,6 +3591,14 @@ SDValue DAGCombiner::visitMULHU(SDNode *N) {
EVT VT = N->getValueType(0);
SDLoc DL(N);
+ if (VT.isVector()) {
+ // fold (mulhu x, 0) -> 0
+ if (ISD::isBuildVectorAllZeros(N1.getNode()))
+ return N1;
+ if (ISD::isBuildVectorAllZeros(N0.getNode()))
+ return N0;
+ }
+
// fold (mulhu x, 0) -> 0
if (isNullConstant(N1))
return N1;
@@ -2525,9 +3606,22 @@ SDValue DAGCombiner::visitMULHU(SDNode *N) {
if (isOneConstant(N1))
return DAG.getConstant(0, DL, N0.getValueType());
// fold (mulhu x, undef) -> 0
- if (N0.getOpcode() == ISD::UNDEF || N1.getOpcode() == ISD::UNDEF)
+ if (N0.isUndef() || N1.isUndef())
return DAG.getConstant(0, DL, VT);
+ // fold (mulhu x, (1 << c)) -> x >> (bitwidth - c)
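+  // The high half of x * (1 << c) is just the top c bits of x, which is
+  // exactly what x >>u (bitwidth - c) produces.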
+ if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
+ DAG.isKnownToBeAPowerOfTwo(N1) && hasOperation(ISD::SRL, VT)) {
+ unsigned NumEltBits = VT.getScalarSizeInBits();
+ SDValue LogBase2 = BuildLogBase2(N1, DL);
+ SDValue SRLAmt = DAG.getNode(
+ ISD::SUB, DL, VT, DAG.getConstant(NumEltBits, DL, VT), LogBase2);
+ EVT ShiftVT = getShiftAmountTy(N0.getValueType());
+ SDValue Trunc = DAG.getZExtOrTrunc(SRLAmt, DL, ShiftVT);
+ return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
+ }
+
// If the type twice as wide is legal, transform the mulhu to a wider multiply
// plus a shift.
if (VT.isSimple() && !VT.isVector()) {
@@ -2555,18 +3649,16 @@ SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
unsigned HiOp) {
// If the high half is not needed, just compute the low half.
bool HiExists = N->hasAnyUseOfValue(1);
- if (!HiExists &&
- (!LegalOperations ||
- TLI.isOperationLegalOrCustom(LoOp, N->getValueType(0)))) {
+ if (!HiExists && (!LegalOperations ||
+ TLI.isOperationLegalOrCustom(LoOp, N->getValueType(0)))) {
SDValue Res = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
return CombineTo(N, Res, Res);
}
// If the low half is not needed, just compute the high half.
bool LoExists = N->hasAnyUseOfValue(0);
- if (!LoExists &&
- (!LegalOperations ||
- TLI.isOperationLegal(HiOp, N->getValueType(1)))) {
+ if (!LoExists && (!LegalOperations ||
+ TLI.isOperationLegalOrCustom(HiOp, N->getValueType(1)))) {
SDValue Res = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
return CombineTo(N, Res, Res);
}
@@ -2582,7 +3674,7 @@ SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
SDValue LoOpt = combine(Lo.getNode());
if (LoOpt.getNode() && LoOpt.getNode() != Lo.getNode() &&
(!LegalOperations ||
- TLI.isOperationLegal(LoOpt.getOpcode(), LoOpt.getValueType())))
+ TLI.isOperationLegalOrCustom(LoOpt.getOpcode(), LoOpt.getValueType())))
return CombineTo(N, LoOpt, LoOpt);
}
@@ -2592,7 +3684,7 @@ SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
SDValue HiOpt = combine(Hi.getNode());
if (HiOpt.getNode() && HiOpt != Hi &&
(!LegalOperations ||
- TLI.isOperationLegal(HiOpt.getOpcode(), HiOpt.getValueType())))
+ TLI.isOperationLegalOrCustom(HiOpt.getOpcode(), HiOpt.getValueType())))
return CombineTo(N, HiOpt, HiOpt);
}
@@ -2691,97 +3783,142 @@ SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
if (SDValue FoldedVOp = SimplifyVBinOp(N))
return FoldedVOp;
- // fold (add c1, c2) -> c1+c2
+ // fold operation with constant operands.
ConstantSDNode *N0C = getAsNonOpaqueConstant(N0);
ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);
if (N0C && N1C)
return DAG.FoldConstantArithmetic(N->getOpcode(), SDLoc(N), VT, N0C, N1C);
// canonicalize constant to RHS
- if (isConstantIntBuildVectorOrConstantInt(N0) &&
- !isConstantIntBuildVectorOrConstantInt(N1))
+ if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
+ !DAG.isConstantIntBuildVectorOrConstantInt(N1))
return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0);
+  // If sign bits are zero, flip between UMIN/UMAX and SMIN/SMAX.
+ // Only do this if the current op isn't legal and the flipped is.
+ unsigned Opcode = N->getOpcode();
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ if (!TLI.isOperationLegal(Opcode, VT) &&
+ (N0.isUndef() || DAG.SignBitIsZero(N0)) &&
+ (N1.isUndef() || DAG.SignBitIsZero(N1))) {
+ unsigned AltOpcode;
+ switch (Opcode) {
+ case ISD::SMIN: AltOpcode = ISD::UMIN; break;
+ case ISD::SMAX: AltOpcode = ISD::UMAX; break;
+ case ISD::UMIN: AltOpcode = ISD::SMIN; break;
+ case ISD::UMAX: AltOpcode = ISD::SMAX; break;
+ default: llvm_unreachable("Unknown MINMAX opcode");
+ }
+ if (TLI.isOperationLegal(AltOpcode, VT))
+ return DAG.getNode(AltOpcode, SDLoc(N), VT, N0, N1);
+ }
+
return SDValue();
}
-/// If this is a binary operator with two operands of the same opcode, try to
-/// simplify it.
-SDValue DAGCombiner::SimplifyBinOpWithSameOpcodeHands(SDNode *N) {
+/// If this is a bitwise logic instruction and both operands have the same
+/// opcode, try to sink the other opcode after the logic instruction.
+SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) {
SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
EVT VT = N0.getValueType();
- assert(N0.getOpcode() == N1.getOpcode() && "Bad input!");
+ unsigned LogicOpcode = N->getOpcode();
+ unsigned HandOpcode = N0.getOpcode();
+ assert((LogicOpcode == ISD::AND || LogicOpcode == ISD::OR ||
+ LogicOpcode == ISD::XOR) && "Expected logic opcode");
+ assert(HandOpcode == N1.getOpcode() && "Bad input!");
// Bail early if none of these transforms apply.
- if (N0.getNode()->getNumOperands() == 0) return SDValue();
-
- // For each of OP in AND/OR/XOR:
- // fold (OP (zext x), (zext y)) -> (zext (OP x, y))
- // fold (OP (sext x), (sext y)) -> (sext (OP x, y))
- // fold (OP (aext x), (aext y)) -> (aext (OP x, y))
- // fold (OP (bswap x), (bswap y)) -> (bswap (OP x, y))
- // fold (OP (trunc x), (trunc y)) -> (trunc (OP x, y)) (if trunc isn't free)
- //
- // do not sink logical op inside of a vector extend, since it may combine
- // into a vsetcc.
- EVT Op0VT = N0.getOperand(0).getValueType();
- if ((N0.getOpcode() == ISD::ZERO_EXTEND ||
- N0.getOpcode() == ISD::SIGN_EXTEND ||
- N0.getOpcode() == ISD::BSWAP ||
- // Avoid infinite looping with PromoteIntBinOp.
- (N0.getOpcode() == ISD::ANY_EXTEND &&
- (!LegalTypes || TLI.isTypeDesirableForOp(N->getOpcode(), Op0VT))) ||
- (N0.getOpcode() == ISD::TRUNCATE &&
- (!TLI.isZExtFree(VT, Op0VT) ||
- !TLI.isTruncateFree(Op0VT, VT)) &&
- TLI.isTypeLegal(Op0VT))) &&
- !VT.isVector() &&
- Op0VT == N1.getOperand(0).getValueType() &&
- (!LegalOperations || TLI.isOperationLegal(N->getOpcode(), Op0VT))) {
- SDValue ORNode = DAG.getNode(N->getOpcode(), SDLoc(N0),
- N0.getOperand(0).getValueType(),
- N0.getOperand(0), N1.getOperand(0));
- AddToWorklist(ORNode.getNode());
- return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, ORNode);
- }
-
- // For each of OP in SHL/SRL/SRA/AND...
- // fold (and (OP x, z), (OP y, z)) -> (OP (and x, y), z)
- // fold (or (OP x, z), (OP y, z)) -> (OP (or x, y), z)
- // fold (xor (OP x, z), (OP y, z)) -> (OP (xor x, y), z)
- if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL ||
- N0.getOpcode() == ISD::SRA || N0.getOpcode() == ISD::AND) &&
+ if (N0.getNumOperands() == 0)
+ return SDValue();
+
+ // FIXME: We should check number of uses of the operands to not increase
+ // the instruction count for all transforms.
+
+ // Handle size-changing casts.
+ SDValue X = N0.getOperand(0);
+ SDValue Y = N1.getOperand(0);
+ EVT XVT = X.getValueType();
+ SDLoc DL(N);
+ if (HandOpcode == ISD::ANY_EXTEND || HandOpcode == ISD::ZERO_EXTEND ||
+ HandOpcode == ISD::SIGN_EXTEND) {
+ // If both operands have other uses, this transform would create extra
+ // instructions without eliminating anything.
+ if (!N0.hasOneUse() && !N1.hasOneUse())
+ return SDValue();
+ // We need matching integer source types.
+ if (XVT != Y.getValueType())
+ return SDValue();
+ // Don't create an illegal op during or after legalization. Don't ever
+ // create an unsupported vector op.
+ if ((VT.isVector() || LegalOperations) &&
+ !TLI.isOperationLegalOrCustom(LogicOpcode, XVT))
+ return SDValue();
+ // Avoid infinite looping with PromoteIntBinOp.
+ // TODO: Should we apply desirable/legal constraints to all opcodes?
+ if (HandOpcode == ISD::ANY_EXTEND && LegalTypes &&
+ !TLI.isTypeDesirableForOp(LogicOpcode, XVT))
+ return SDValue();
+ // logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y)
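+    // E.g. and (zext i8 X to i32), (zext i8 Y to i32) --> zext (and X, Y):
+    // bitwise logic commutes with the extension bit-for-bit.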
+ SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
+ return DAG.getNode(HandOpcode, DL, VT, Logic);
+ }
+
+ // logic_op (truncate x), (truncate y) --> truncate (logic_op x, y)
+ if (HandOpcode == ISD::TRUNCATE) {
+ // If both operands have other uses, this transform would create extra
+ // instructions without eliminating anything.
+ if (!N0.hasOneUse() && !N1.hasOneUse())
+ return SDValue();
+ // We need matching source types.
+ if (XVT != Y.getValueType())
+ return SDValue();
+ // Don't create an illegal op during or after legalization.
+ if (LegalOperations && !TLI.isOperationLegal(LogicOpcode, XVT))
+ return SDValue();
+ // Be extra careful sinking truncate. If it's free, there's no benefit in
+ // widening a binop. Also, don't create a logic op on an illegal type.
+ if (TLI.isZExtFree(VT, XVT) && TLI.isTruncateFree(XVT, VT))
+ return SDValue();
+ if (!TLI.isTypeLegal(XVT))
+ return SDValue();
+ SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
+ return DAG.getNode(HandOpcode, DL, VT, Logic);
+ }
+
+ // For binops SHL/SRL/SRA/AND:
+ // logic_op (OP x, z), (OP y, z) --> OP (logic_op x, y), z
+ if ((HandOpcode == ISD::SHL || HandOpcode == ISD::SRL ||
+ HandOpcode == ISD::SRA || HandOpcode == ISD::AND) &&
N0.getOperand(1) == N1.getOperand(1)) {
- SDValue ORNode = DAG.getNode(N->getOpcode(), SDLoc(N0),
- N0.getOperand(0).getValueType(),
- N0.getOperand(0), N1.getOperand(0));
- AddToWorklist(ORNode.getNode());
- return DAG.getNode(N0.getOpcode(), SDLoc(N), VT,
- ORNode, N0.getOperand(1));
+ // If either operand has other uses, this transform is not an improvement.
+ if (!N0.hasOneUse() || !N1.hasOneUse())
+ return SDValue();
+ SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
+ return DAG.getNode(HandOpcode, DL, VT, Logic, N0.getOperand(1));
+ }
+
+ // Unary ops: logic_op (bswap x), (bswap y) --> bswap (logic_op x, y)
+ if (HandOpcode == ISD::BSWAP) {
+ // If either operand has other uses, this transform is not an improvement.
+ if (!N0.hasOneUse() || !N1.hasOneUse())
+ return SDValue();
+ SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
+ return DAG.getNode(HandOpcode, DL, VT, Logic);
}
// Simplify xor/and/or (bitcast(A), bitcast(B)) -> bitcast(op (A,B))
- // Only perform this optimization after type legalization and before
+ // Only perform this optimization up until type legalization, before
   // LegalizeVectorOps. LegalizeVectorOps promotes vector operations by
// adding bitcasts. For example (xor v4i32) is promoted to (v2i64), and
// we don't want to undo this promotion.
// We also handle SCALAR_TO_VECTOR because xor/or/and operations are cheaper
// on scalars.
- if ((N0.getOpcode() == ISD::BITCAST ||
- N0.getOpcode() == ISD::SCALAR_TO_VECTOR) &&
- Level == AfterLegalizeTypes) {
- SDValue In0 = N0.getOperand(0);
- SDValue In1 = N1.getOperand(0);
- EVT In0Ty = In0.getValueType();
- EVT In1Ty = In1.getValueType();
- SDLoc DL(N);
- // If both incoming values are integers, and the original types are the
- // same.
- if (In0Ty.isInteger() && In1Ty.isInteger() && In0Ty == In1Ty) {
- SDValue Op = DAG.getNode(N->getOpcode(), DL, In0Ty, In0, In1);
- SDValue BC = DAG.getNode(N0.getOpcode(), DL, VT, Op);
- AddToWorklist(Op.getNode());
- return BC;
+ if ((HandOpcode == ISD::BITCAST || HandOpcode == ISD::SCALAR_TO_VECTOR) &&
+ Level <= AfterLegalizeTypes) {
+ // Input types must be integer and the same.
+ if (XVT.isInteger() && XVT == Y.getValueType()) {
+ SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
+ return DAG.getNode(HandOpcode, DL, VT, Logic);
}
}
@@ -2797,167 +3934,210 @@ SDValue DAGCombiner::SimplifyBinOpWithSameOpcodeHands(SDNode *N) {
// If both shuffles use the same mask, and both shuffles have the same first
// or second operand, then it might still be profitable to move the shuffle
// after the xor/and/or operation.
- if (N0.getOpcode() == ISD::VECTOR_SHUFFLE && Level < AfterLegalizeDAG) {
- ShuffleVectorSDNode *SVN0 = cast<ShuffleVectorSDNode>(N0);
- ShuffleVectorSDNode *SVN1 = cast<ShuffleVectorSDNode>(N1);
-
- assert(N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType() &&
+ if (HandOpcode == ISD::VECTOR_SHUFFLE && Level < AfterLegalizeDAG) {
+ auto *SVN0 = cast<ShuffleVectorSDNode>(N0);
+ auto *SVN1 = cast<ShuffleVectorSDNode>(N1);
+ assert(X.getValueType() == Y.getValueType() &&
"Inputs to shuffles are not the same type");
// Check that both shuffles use the same mask. The masks are known to be of
// the same length because the result vector type is the same.
// Check also that shuffles have only one use to avoid introducing extra
// instructions.
- if (SVN0->hasOneUse() && SVN1->hasOneUse() &&
- SVN0->getMask().equals(SVN1->getMask())) {
- SDValue ShOp = N0->getOperand(1);
-
- // Don't try to fold this node if it requires introducing a
- // build vector of all zeros that might be illegal at this stage.
- if (N->getOpcode() == ISD::XOR && ShOp.getOpcode() != ISD::UNDEF) {
- if (!LegalTypes)
- ShOp = DAG.getConstant(0, SDLoc(N), VT);
- else
- ShOp = SDValue();
- }
-
- // (AND (shuf (A, C), shuf (B, C)) -> shuf (AND (A, B), C)
- // (OR (shuf (A, C), shuf (B, C)) -> shuf (OR (A, B), C)
- // (XOR (shuf (A, C), shuf (B, C)) -> shuf (XOR (A, B), V_0)
- if (N0.getOperand(1) == N1.getOperand(1) && ShOp.getNode()) {
- SDValue NewNode = DAG.getNode(N->getOpcode(), SDLoc(N), VT,
- N0->getOperand(0), N1->getOperand(0));
- AddToWorklist(NewNode.getNode());
- return DAG.getVectorShuffle(VT, SDLoc(N), NewNode, ShOp,
- &SVN0->getMask()[0]);
- }
-
- // Don't try to fold this node if it requires introducing a
- // build vector of all zeros that might be illegal at this stage.
- ShOp = N0->getOperand(0);
- if (N->getOpcode() == ISD::XOR && ShOp.getOpcode() != ISD::UNDEF) {
- if (!LegalTypes)
- ShOp = DAG.getConstant(0, SDLoc(N), VT);
- else
- ShOp = SDValue();
- }
+ if (!SVN0->hasOneUse() || !SVN1->hasOneUse() ||
+ !SVN0->getMask().equals(SVN1->getMask()))
+ return SDValue();
- // (AND (shuf (C, A), shuf (C, B)) -> shuf (C, AND (A, B))
- // (OR (shuf (C, A), shuf (C, B)) -> shuf (C, OR (A, B))
- // (XOR (shuf (C, A), shuf (C, B)) -> shuf (V_0, XOR (A, B))
- if (N0->getOperand(0) == N1->getOperand(0) && ShOp.getNode()) {
- SDValue NewNode = DAG.getNode(N->getOpcode(), SDLoc(N), VT,
- N0->getOperand(1), N1->getOperand(1));
- AddToWorklist(NewNode.getNode());
- return DAG.getVectorShuffle(VT, SDLoc(N), ShOp, NewNode,
- &SVN0->getMask()[0]);
- }
+ // Don't try to fold this node if it requires introducing a
+ // build vector of all zeros that might be illegal at this stage.
+ SDValue ShOp = N0.getOperand(1);
+ if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
+ ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
+
+ // (logic_op (shuf (A, C), shuf (B, C))) --> shuf (logic_op (A, B), C)
+ if (N0.getOperand(1) == N1.getOperand(1) && ShOp.getNode()) {
+ SDValue Logic = DAG.getNode(LogicOpcode, DL, VT,
+ N0.getOperand(0), N1.getOperand(0));
+ return DAG.getVectorShuffle(VT, DL, Logic, ShOp, SVN0->getMask());
+ }
+
+ // Don't try to fold this node if it requires introducing a
+ // build vector of all zeros that might be illegal at this stage.
+ ShOp = N0.getOperand(0);
+ if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
+ ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
+
+ // (logic_op (shuf (C, A), shuf (C, B))) --> shuf (C, logic_op (A, B))
+ if (N0.getOperand(0) == N1.getOperand(0) && ShOp.getNode()) {
+ SDValue Logic = DAG.getNode(LogicOpcode, DL, VT, N0.getOperand(1),
+ N1.getOperand(1));
+ return DAG.getVectorShuffle(VT, DL, ShOp, Logic, SVN0->getMask());
}
}
return SDValue();
}
+/// Try to make (and/or setcc (LL, LR), setcc (RL, RR)) more efficient.
+SDValue DAGCombiner::foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
+ const SDLoc &DL) {
+ SDValue LL, LR, RL, RR, N0CC, N1CC;
+ if (!isSetCCEquivalent(N0, LL, LR, N0CC) ||
+ !isSetCCEquivalent(N1, RL, RR, N1CC))
+ return SDValue();
+
+ assert(N0.getValueType() == N1.getValueType() &&
+ "Unexpected operand types for bitwise logic op");
+ assert(LL.getValueType() == LR.getValueType() &&
+ RL.getValueType() == RR.getValueType() &&
+ "Unexpected operand types for setcc");
+
+ // If we're here post-legalization or the logic op type is not i1, the logic
+ // op type must match a setcc result type. Also, all folds require new
+ // operations on the left and right operands, so those types must match.
+ EVT VT = N0.getValueType();
+ EVT OpVT = LL.getValueType();
+ if (LegalOperations || VT.getScalarType() != MVT::i1)
+ if (VT != getSetCCResultType(OpVT))
+ return SDValue();
+ if (OpVT != RL.getValueType())
+ return SDValue();
+
+ ISD::CondCode CC0 = cast<CondCodeSDNode>(N0CC)->get();
+ ISD::CondCode CC1 = cast<CondCodeSDNode>(N1CC)->get();
+ bool IsInteger = OpVT.isInteger();
+ if (LR == RR && CC0 == CC1 && IsInteger) {
+ bool IsZero = isNullOrNullSplat(LR);
+ bool IsNeg1 = isAllOnesOrAllOnesSplat(LR);
+
+ // All bits clear?
+ bool AndEqZero = IsAnd && CC1 == ISD::SETEQ && IsZero;
+ // All sign bits clear?
+ bool AndGtNeg1 = IsAnd && CC1 == ISD::SETGT && IsNeg1;
+ // Any bits set?
+ bool OrNeZero = !IsAnd && CC1 == ISD::SETNE && IsZero;
+ // Any sign bits set?
+ bool OrLtZero = !IsAnd && CC1 == ISD::SETLT && IsZero;
+
+ // (and (seteq X, 0), (seteq Y, 0)) --> (seteq (or X, Y), 0)
+ // (and (setgt X, -1), (setgt Y, -1)) --> (setgt (or X, Y), -1)
+ // (or (setne X, 0), (setne Y, 0)) --> (setne (or X, Y), 0)
+ // (or (setlt X, 0), (setlt Y, 0)) --> (setlt (or X, Y), 0)
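+    // These hold because (or X, Y) is zero only if both inputs are zero,
+    // and its sign bit is clear only if both sign bits are clear.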
+ if (AndEqZero || AndGtNeg1 || OrNeZero || OrLtZero) {
+ SDValue Or = DAG.getNode(ISD::OR, SDLoc(N0), OpVT, LL, RL);
+ AddToWorklist(Or.getNode());
+ return DAG.getSetCC(DL, VT, Or, LR, CC1);
+ }
+
+ // All bits set?
+ bool AndEqNeg1 = IsAnd && CC1 == ISD::SETEQ && IsNeg1;
+ // All sign bits set?
+ bool AndLtZero = IsAnd && CC1 == ISD::SETLT && IsZero;
+ // Any bits clear?
+ bool OrNeNeg1 = !IsAnd && CC1 == ISD::SETNE && IsNeg1;
+ // Any sign bits clear?
+ bool OrGtNeg1 = !IsAnd && CC1 == ISD::SETGT && IsNeg1;
+
+ // (and (seteq X, -1), (seteq Y, -1)) --> (seteq (and X, Y), -1)
+ // (and (setlt X, 0), (setlt Y, 0)) --> (setlt (and X, Y), 0)
+ // (or (setne X, -1), (setne Y, -1)) --> (setne (and X, Y), -1)
+    // (or (setgt X, -1), (setgt Y, -1)) --> (setgt (and X, Y), -1)
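+    // Dually, (and X, Y) is all-ones only if both inputs are all-ones,
+    // and its sign bit is set only if both sign bits are set.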
+ if (AndEqNeg1 || AndLtZero || OrNeNeg1 || OrGtNeg1) {
+ SDValue And = DAG.getNode(ISD::AND, SDLoc(N0), OpVT, LL, RL);
+ AddToWorklist(And.getNode());
+ return DAG.getSetCC(DL, VT, And, LR, CC1);
+ }
+ }
+
+ // TODO: What is the 'or' equivalent of this fold?
+ // (and (setne X, 0), (setne X, -1)) --> (setuge (add X, 1), 2)
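+  // Adding 1 maps X == -1 to 0 and X == 0 to 1, so X avoids both values
+  // exactly when (X + 1) >=u 2 (the unsigned compare catches the wrap).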
+ if (IsAnd && LL == RL && CC0 == CC1 && OpVT.getScalarSizeInBits() > 1 &&
+ IsInteger && CC0 == ISD::SETNE &&
+ ((isNullConstant(LR) && isAllOnesConstant(RR)) ||
+ (isAllOnesConstant(LR) && isNullConstant(RR)))) {
+ SDValue One = DAG.getConstant(1, DL, OpVT);
+ SDValue Two = DAG.getConstant(2, DL, OpVT);
+ SDValue Add = DAG.getNode(ISD::ADD, SDLoc(N0), OpVT, LL, One);
+ AddToWorklist(Add.getNode());
+ return DAG.getSetCC(DL, VT, Add, Two, ISD::SETUGE);
+ }
+
+ // Try more general transforms if the predicates match and the only user of
+ // the compares is the 'and' or 'or'.
+ if (IsInteger && TLI.convertSetCCLogicToBitwiseLogic(OpVT) && CC0 == CC1 &&
+ N0.hasOneUse() && N1.hasOneUse()) {
+ // and (seteq A, B), (seteq C, D) --> seteq (or (xor A, B), (xor C, D)), 0
+ // or (setne A, B), (setne C, D) --> setne (or (xor A, B), (xor C, D)), 0
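+    // A == B exactly when (xor A, B) is zero, and both xors are zero
+    // exactly when their or is zero.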
+ if ((IsAnd && CC1 == ISD::SETEQ) || (!IsAnd && CC1 == ISD::SETNE)) {
+ SDValue XorL = DAG.getNode(ISD::XOR, SDLoc(N0), OpVT, LL, LR);
+ SDValue XorR = DAG.getNode(ISD::XOR, SDLoc(N1), OpVT, RL, RR);
+ SDValue Or = DAG.getNode(ISD::OR, DL, OpVT, XorL, XorR);
+ SDValue Zero = DAG.getConstant(0, DL, OpVT);
+ return DAG.getSetCC(DL, VT, Or, Zero, CC1);
+ }
+ }
+
+ // Canonicalize equivalent operands to LL == RL.
+ if (LL == RR && LR == RL) {
+ CC1 = ISD::getSetCCSwappedOperands(CC1);
+ std::swap(RL, RR);
+ }
+
+ // (and (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
+ // (or (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
+ if (LL == RL && LR == RR) {
+ ISD::CondCode NewCC = IsAnd ? ISD::getSetCCAndOperation(CC0, CC1, IsInteger)
+ : ISD::getSetCCOrOperation(CC0, CC1, IsInteger);
+ if (NewCC != ISD::SETCC_INVALID &&
+ (!LegalOperations ||
+ (TLI.isCondCodeLegal(NewCC, LL.getSimpleValueType()) &&
+ TLI.isOperationLegal(ISD::SETCC, OpVT))))
+ return DAG.getSetCC(DL, VT, LL, LR, NewCC);
+ }
+
+ return SDValue();
+}
+
/// This contains all DAGCombine rules which reduce two values combined by
/// an And operation to a single value. This makes them reusable in the context
/// of visitSELECT(). Rules involving constants are not included as
/// visitSELECT() already handles those cases.
-SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1,
- SDNode *LocReference) {
+SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1, SDNode *N) {
EVT VT = N1.getValueType();
+ SDLoc DL(N);
// fold (and x, undef) -> 0
- if (N0.getOpcode() == ISD::UNDEF || N1.getOpcode() == ISD::UNDEF)
- return DAG.getConstant(0, SDLoc(LocReference), VT);
- // fold (and (setcc x), (setcc y)) -> (setcc (and x, y))
- SDValue LL, LR, RL, RR, CC0, CC1;
- if (isSetCCEquivalent(N0, LL, LR, CC0) && isSetCCEquivalent(N1, RL, RR, CC1)){
- ISD::CondCode Op0 = cast<CondCodeSDNode>(CC0)->get();
- ISD::CondCode Op1 = cast<CondCodeSDNode>(CC1)->get();
-
- if (LR == RR && isa<ConstantSDNode>(LR) && Op0 == Op1 &&
- LL.getValueType().isInteger()) {
- // fold (and (seteq X, 0), (seteq Y, 0)) -> (seteq (or X, Y), 0)
- if (isNullConstant(LR) && Op1 == ISD::SETEQ) {
- SDValue ORNode = DAG.getNode(ISD::OR, SDLoc(N0),
- LR.getValueType(), LL, RL);
- AddToWorklist(ORNode.getNode());
- return DAG.getSetCC(SDLoc(LocReference), VT, ORNode, LR, Op1);
- }
- if (isAllOnesConstant(LR)) {
- // fold (and (seteq X, -1), (seteq Y, -1)) -> (seteq (and X, Y), -1)
- if (Op1 == ISD::SETEQ) {
- SDValue ANDNode = DAG.getNode(ISD::AND, SDLoc(N0),
- LR.getValueType(), LL, RL);
- AddToWorklist(ANDNode.getNode());
- return DAG.getSetCC(SDLoc(LocReference), VT, ANDNode, LR, Op1);
- }
- // fold (and (setgt X, -1), (setgt Y, -1)) -> (setgt (or X, Y), -1)
- if (Op1 == ISD::SETGT) {
- SDValue ORNode = DAG.getNode(ISD::OR, SDLoc(N0),
- LR.getValueType(), LL, RL);
- AddToWorklist(ORNode.getNode());
- return DAG.getSetCC(SDLoc(LocReference), VT, ORNode, LR, Op1);
- }
- }
- }
- // Simplify (and (setne X, 0), (setne X, -1)) -> (setuge (add X, 1), 2)
- if (LL == RL && isa<ConstantSDNode>(LR) && isa<ConstantSDNode>(RR) &&
- Op0 == Op1 && LL.getValueType().isInteger() &&
- Op0 == ISD::SETNE && ((isNullConstant(LR) && isAllOnesConstant(RR)) ||
- (isAllOnesConstant(LR) && isNullConstant(RR)))) {
- SDLoc DL(N0);
- SDValue ADDNode = DAG.getNode(ISD::ADD, DL, LL.getValueType(),
- LL, DAG.getConstant(1, DL,
- LL.getValueType()));
- AddToWorklist(ADDNode.getNode());
- return DAG.getSetCC(SDLoc(LocReference), VT, ADDNode,
- DAG.getConstant(2, DL, LL.getValueType()),
- ISD::SETUGE);
- }
- // canonicalize equivalent to ll == rl
- if (LL == RR && LR == RL) {
- Op1 = ISD::getSetCCSwappedOperands(Op1);
- std::swap(RL, RR);
- }
- if (LL == RL && LR == RR) {
- bool isInteger = LL.getValueType().isInteger();
- ISD::CondCode Result = ISD::getSetCCAndOperation(Op0, Op1, isInteger);
- if (Result != ISD::SETCC_INVALID &&
- (!LegalOperations ||
- (TLI.isCondCodeLegal(Result, LL.getSimpleValueType()) &&
- TLI.isOperationLegal(ISD::SETCC, LL.getValueType())))) {
- EVT CCVT = getSetCCResultType(LL.getValueType());
- if (N0.getValueType() == CCVT ||
- (!LegalOperations && N0.getValueType() == MVT::i1))
- return DAG.getSetCC(SDLoc(LocReference), N0.getValueType(),
- LL, LR, Result);
- }
- }
- }
+ if (N0.isUndef() || N1.isUndef())
+ return DAG.getConstant(0, DL, VT);
+
+ if (SDValue V = foldLogicOfSetCCs(true, N0, N1, DL))
+ return V;
if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::SRL &&
VT.getSizeInBits() <= 64) {
if (ConstantSDNode *ADDI = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
- APInt ADDC = ADDI->getAPIntValue();
- if (!TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
+ if (ConstantSDNode *SRLI = dyn_cast<ConstantSDNode>(N1.getOperand(1))) {
// Look for (and (add x, c1), (lshr y, c2)). If C1 wasn't a legal
// immediate for an add, but it is legal if its top c2 bits are set,
// transform the ADD so the immediate doesn't need to be materialized
// in a register.
- if (ConstantSDNode *SRLI = dyn_cast<ConstantSDNode>(N1.getOperand(1))) {
+ APInt ADDC = ADDI->getAPIntValue();
+ APInt SRLC = SRLI->getAPIntValue();
+ if (ADDC.getMinSignedBits() <= 64 &&
+ SRLC.ult(VT.getSizeInBits()) &&
+ !TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
APInt Mask = APInt::getHighBitsSet(VT.getSizeInBits(),
- SRLI->getZExtValue());
+ SRLC.getZExtValue());
if (DAG.MaskedValueIsZero(N0.getOperand(1), Mask)) {
ADDC |= Mask;
if (TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
- SDLoc DL(N0);
+ SDLoc DL0(N0);
SDValue NewAdd =
- DAG.getNode(ISD::ADD, DL, VT,
+ DAG.getNode(ISD::ADD, DL0, VT,
N0.getOperand(0), DAG.getConstant(ADDC, DL, VT));
CombineTo(N0.getNode(), NewAdd);
// Return N so it doesn't get rechecked!
- return SDValue(LocReference, 0);
+ return SDValue(N, 0);
}
}
}
@@ -2965,26 +4145,73 @@ SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1,
}
}
+ // Reduce bit extract of low half of an integer to the narrower type.
+ // (and (srl i64:x, K), KMask) ->
+  //   (i64 zero_extend (and (srl (i32 (trunc i64:x)), K), KMask))
+ if (N0.getOpcode() == ISD::SRL && N0.hasOneUse()) {
+ if (ConstantSDNode *CAnd = dyn_cast<ConstantSDNode>(N1)) {
+ if (ConstantSDNode *CShift = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
+ unsigned Size = VT.getSizeInBits();
+ const APInt &AndMask = CAnd->getAPIntValue();
+ unsigned ShiftBits = CShift->getZExtValue();
+
+ // Bail out, this node will probably disappear anyway.
+ if (ShiftBits == 0)
+ return SDValue();
+
+ unsigned MaskBits = AndMask.countTrailingOnes();
+ EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), Size / 2);
+
+ if (AndMask.isMask() &&
+ // Required bits must not span the two halves of the integer and
+ // must fit in the half size type.
+ (ShiftBits + MaskBits <= Size / 2) &&
+ TLI.isNarrowingProfitable(VT, HalfVT) &&
+ TLI.isTypeDesirableForOp(ISD::AND, HalfVT) &&
+ TLI.isTypeDesirableForOp(ISD::SRL, HalfVT) &&
+ TLI.isTruncateFree(VT, HalfVT) &&
+ TLI.isZExtFree(HalfVT, VT)) {
+ // The isNarrowingProfitable is to avoid regressions on PPC and
+ // AArch64 which match a few 64-bit bit insert / bit extract patterns
+ // on downstream users of this. Those patterns could probably be
+ // extended to handle extensions mixed in.
+
+        SDLoc SL(N0);
+ assert(MaskBits <= Size);
+
+ // Extracting the highest bit of the low half.
+ EVT ShiftVT = TLI.getShiftAmountTy(HalfVT, DAG.getDataLayout());
+ SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, HalfVT,
+ N0.getOperand(0));
+
+ SDValue NewMask = DAG.getConstant(AndMask.trunc(Size / 2), SL, HalfVT);
+ SDValue ShiftK = DAG.getConstant(ShiftBits, SL, ShiftVT);
+ SDValue Shift = DAG.getNode(ISD::SRL, SL, HalfVT, Trunc, ShiftK);
+ SDValue And = DAG.getNode(ISD::AND, SL, HalfVT, Shift, NewMask);
+ return DAG.getNode(ISD::ZERO_EXTEND, SL, VT, And);
+ }
+ }
+ }
+ }
+
return SDValue();
}
bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
- EVT LoadResultTy, EVT &ExtVT, EVT &LoadedVT,
- bool &NarrowLoad) {
- uint32_t ActiveBits = AndC->getAPIntValue().getActiveBits();
-
- if (ActiveBits == 0 || !APIntOps::isMask(ActiveBits, AndC->getAPIntValue()))
+ EVT LoadResultTy, EVT &ExtVT) {
+ if (!AndC->getAPIntValue().isMask())
return false;
+ unsigned ActiveBits = AndC->getAPIntValue().countTrailingOnes();
+
ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
- LoadedVT = LoadN->getMemoryVT();
+ EVT LoadedVT = LoadN->getMemoryVT();
if (ExtVT == LoadedVT &&
(!LegalOperations ||
TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))) {
// ZEXTLOAD will match without needing to change the size of the value being
// loaded.
- NarrowLoad = false;
return true;
}
@@ -3004,15 +4231,306 @@ bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
if (!TLI.shouldReduceLoadWidth(LoadN, ISD::ZEXTLOAD, ExtVT))
return false;
- NarrowLoad = true;
return true;
}
+bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST,
+ ISD::LoadExtType ExtType, EVT &MemVT,
+ unsigned ShAmt) {
+ if (!LDST)
+ return false;
+ // Only allow byte offsets.
+ if (ShAmt % 8)
+ return false;
+
+ // Do not generate loads of non-round integer types since these can
+ // be expensive (and would be wrong if the type is not byte sized).
+ if (!MemVT.isRound())
+ return false;
+
+ // Don't change the width of a volatile load.
+ if (LDST->isVolatile())
+ return false;
+
+ // Verify that we are actually reducing a load width here.
+ if (LDST->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits())
+ return false;
+
+ // Ensure that this isn't going to produce an unsupported unaligned access.
+ if (ShAmt &&
+ !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
+ LDST->getAddressSpace(), ShAmt / 8))
+ return false;
+
+ // It's not possible to generate a constant of extended or untyped type.
+ EVT PtrType = LDST->getBasePtr().getValueType();
+ if (PtrType == MVT::Untyped || PtrType.isExtended())
+ return false;
+
+ if (isa<LoadSDNode>(LDST)) {
+ LoadSDNode *Load = cast<LoadSDNode>(LDST);
+ // Don't transform one with multiple uses, this would require adding a new
+ // load.
+ if (!SDValue(Load, 0).hasOneUse())
+ return false;
+
+ if (LegalOperations &&
+ !TLI.isLoadExtLegal(ExtType, Load->getValueType(0), MemVT))
+ return false;
+
+ // For the transform to be legal, the load must produce only two values
+ // (the value loaded and the chain). Don't transform a pre-increment
+ // load, for example, which produces an extra value. Otherwise the
+ // transformation is not equivalent, and the downstream logic to replace
+ // uses gets things wrong.
+ if (Load->getNumValues() > 2)
+ return false;
+
+ // If the load that we're shrinking is an extload and we're not just
+ // discarding the extension we can't simply shrink the load. Bail.
+ // TODO: It would be possible to merge the extensions in some cases.
+ if (Load->getExtensionType() != ISD::NON_EXTLOAD &&
+ Load->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
+ return false;
+
+ if (!TLI.shouldReduceLoadWidth(Load, ExtType, MemVT))
+ return false;
+ } else {
+ assert(isa<StoreSDNode>(LDST) && "It is not a Load nor a Store SDNode");
+ StoreSDNode *Store = cast<StoreSDNode>(LDST);
+ // Can't write outside the original store
+ if (Store->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
+ return false;
+
+ if (LegalOperations &&
+ !TLI.isTruncStoreLegal(Store->getValue().getValueType(), MemVT))
+ return false;
+ }
+ return true;
+}
+
+bool DAGCombiner::SearchForAndLoads(SDNode *N,
+ SmallVectorImpl<LoadSDNode*> &Loads,
+ SmallPtrSetImpl<SDNode*> &NodesWithConsts,
+ ConstantSDNode *Mask,
+ SDNode *&NodeToMask) {
+ // Recursively search for the operands, looking for loads which can be
+ // narrowed.
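+  // E.g. for (and (or (load a), (load b)), 0xFF), both loads can typically
+  // be rewritten as i8 zextloads, making the mask redundant.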
+ for (unsigned i = 0, e = N->getNumOperands(); i < e; ++i) {
+ SDValue Op = N->getOperand(i);
+
+ if (Op.getValueType().isVector())
+ return false;
+
+ // Some constants may need fixing up later if they are too large.
+ if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
+ if ((N->getOpcode() == ISD::OR || N->getOpcode() == ISD::XOR) &&
+ (Mask->getAPIntValue() & C->getAPIntValue()) != C->getAPIntValue())
+ NodesWithConsts.insert(N);
+ continue;
+ }
+
+ if (!Op.hasOneUse())
+ return false;
+
+ switch(Op.getOpcode()) {
+ case ISD::LOAD: {
+ auto *Load = cast<LoadSDNode>(Op);
+ EVT ExtVT;
+ if (isAndLoadExtLoad(Mask, Load, Load->getValueType(0), ExtVT) &&
+ isLegalNarrowLdSt(Load, ISD::ZEXTLOAD, ExtVT)) {
+
+ // ZEXTLOAD is already small enough.
+ if (Load->getExtensionType() == ISD::ZEXTLOAD &&
+ ExtVT.bitsGE(Load->getMemoryVT()))
+ continue;
+
+ // Use LE to convert equal sized loads to zext.
+ if (ExtVT.bitsLE(Load->getMemoryVT()))
+ Loads.push_back(Load);
+
+ continue;
+ }
+ return false;
+ }
+ case ISD::ZERO_EXTEND:
+ case ISD::AssertZext: {
+ unsigned ActiveBits = Mask->getAPIntValue().countTrailingOnes();
+ EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
+ EVT VT = Op.getOpcode() == ISD::AssertZext ?
+ cast<VTSDNode>(Op.getOperand(1))->getVT() :
+ Op.getOperand(0).getValueType();
+
+      // We can accept extending nodes if the mask is wider than or equal in
+      // width to the original type.
+ if (ExtVT.bitsGE(VT))
+ continue;
+ break;
+ }
+ case ISD::OR:
+ case ISD::XOR:
+ case ISD::AND:
+ if (!SearchForAndLoads(Op.getNode(), Loads, NodesWithConsts, Mask,
+ NodeToMask))
+ return false;
+ continue;
+ }
+
+    // Allow one node which will be masked along with any loads found.
+ if (NodeToMask)
+ return false;
+
+ // Also ensure that the node to be masked only produces one data result.
+ NodeToMask = Op.getNode();
+ if (NodeToMask->getNumValues() > 1) {
+ bool HasValue = false;
+ for (unsigned i = 0, e = NodeToMask->getNumValues(); i < e; ++i) {
+ MVT VT = SDValue(NodeToMask, i).getSimpleValueType();
+ if (VT != MVT::Glue && VT != MVT::Other) {
+ if (HasValue) {
+ NodeToMask = nullptr;
+ return false;
+ }
+ HasValue = true;
+ }
+ }
+ assert(HasValue && "Node to be masked has no data result?");
+ }
+ }
+ return true;
+}
+
+bool DAGCombiner::BackwardsPropagateMask(SDNode *N, SelectionDAG &DAG) {
+ auto *Mask = dyn_cast<ConstantSDNode>(N->getOperand(1));
+ if (!Mask)
+ return false;
+
+ if (!Mask->getAPIntValue().isMask())
+ return false;
+
+ // No need to do anything if the and directly uses a load.
+ if (isa<LoadSDNode>(N->getOperand(0)))
+ return false;
+
+ SmallVector<LoadSDNode*, 8> Loads;
+ SmallPtrSet<SDNode*, 2> NodesWithConsts;
+ SDNode *FixupNode = nullptr;
+ if (SearchForAndLoads(N, Loads, NodesWithConsts, Mask, FixupNode)) {
+ if (Loads.size() == 0)
+ return false;
+
+ LLVM_DEBUG(dbgs() << "Backwards propagate AND: "; N->dump());
+ SDValue MaskOp = N->getOperand(1);
+
+ // If it exists, fixup the single node we allow in the tree that needs
+ // masking.
+ if (FixupNode) {
+ LLVM_DEBUG(dbgs() << "First, need to fix up: "; FixupNode->dump());
+ SDValue And = DAG.getNode(ISD::AND, SDLoc(FixupNode),
+ FixupNode->getValueType(0),
+ SDValue(FixupNode, 0), MaskOp);
+ DAG.ReplaceAllUsesOfValueWith(SDValue(FixupNode, 0), And);
+      if (And.getOpcode() == ISD::AND)
+ DAG.UpdateNodeOperands(And.getNode(), SDValue(FixupNode, 0), MaskOp);
+ }
+
+ // Narrow any constants that need it.
+ for (auto *LogicN : NodesWithConsts) {
+ SDValue Op0 = LogicN->getOperand(0);
+ SDValue Op1 = LogicN->getOperand(1);
+
+ if (isa<ConstantSDNode>(Op0))
+ std::swap(Op0, Op1);
+
+ SDValue And = DAG.getNode(ISD::AND, SDLoc(Op1), Op1.getValueType(),
+ Op1, MaskOp);
+
+ DAG.UpdateNodeOperands(LogicN, Op0, And);
+ }
+
+ // Create narrow loads.
+ for (auto *Load : Loads) {
+ LLVM_DEBUG(dbgs() << "Propagate AND back to: "; Load->dump());
+ SDValue And = DAG.getNode(ISD::AND, SDLoc(Load), Load->getValueType(0),
+ SDValue(Load, 0), MaskOp);
+ DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), And);
+      if (And.getOpcode() == ISD::AND)
+ And = SDValue(
+ DAG.UpdateNodeOperands(And.getNode(), SDValue(Load, 0), MaskOp), 0);
+ SDValue NewLoad = ReduceLoadWidth(And.getNode());
+ assert(NewLoad &&
+ "Shouldn't be masking the load if it can't be narrowed");
+ CombineTo(Load, NewLoad, NewLoad.getValue(1));
+ }
+ DAG.ReplaceAllUsesWith(N, N->getOperand(0).getNode());
+ return true;
+ }
+ return false;
+}
+
+// Unfold
+// x & (-1 'logical shift' y)
+// To
+// (x 'opposite logical shift' y) 'logical shift' y
+// if it is better for performance.
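+// E.g. x & (-1 << y) --> (x >>u y) << y: both clear the low y bits, but
+// the shift pair avoids materializing the all-ones mask in a register.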
+SDValue DAGCombiner::unfoldExtremeBitClearingToShifts(SDNode *N) {
+ assert(N->getOpcode() == ISD::AND);
+
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+
+ // Do we actually prefer shifts over mask?
+ if (!TLI.preferShiftsToClearExtremeBits(N0))
+ return SDValue();
+
+ // Try to match (-1 '[outer] logical shift' y)
+ unsigned OuterShift;
+ unsigned InnerShift; // The opposite direction to the OuterShift.
+ SDValue Y; // Shift amount.
+ auto matchMask = [&OuterShift, &InnerShift, &Y](SDValue M) -> bool {
+ if (!M.hasOneUse())
+ return false;
+ OuterShift = M->getOpcode();
+ if (OuterShift == ISD::SHL)
+ InnerShift = ISD::SRL;
+ else if (OuterShift == ISD::SRL)
+ InnerShift = ISD::SHL;
+ else
+ return false;
+ if (!isAllOnesConstant(M->getOperand(0)))
+ return false;
+ Y = M->getOperand(1);
+ return true;
+ };
+
+ SDValue X;
+ if (matchMask(N1))
+ X = N0;
+ else if (matchMask(N0))
+ X = N1;
+ else
+ return SDValue();
+
+ SDLoc DL(N);
+ EVT VT = N->getValueType(0);
+
+ // tmp = x 'opposite logical shift' y
+ SDValue T0 = DAG.getNode(InnerShift, DL, VT, X, Y);
+ // ret = tmp 'logical shift' y
+ SDValue T1 = DAG.getNode(OuterShift, DL, VT, T0, Y);
+
+ return T1;
+}
+
SDValue DAGCombiner::visitAND(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
EVT VT = N1.getValueType();
+ // x & x --> x
+ if (N0 == N1)
+ return N0;
+
// fold vector ops
if (VT.isVector()) {
if (SDValue FoldedVOp = SimplifyVBinOp(N))
@@ -3021,16 +4539,12 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
// fold (and x, 0) -> 0, vector edition
if (ISD::isBuildVectorAllZeros(N0.getNode()))
// do not return N0, because undef node may exist in N0
- return DAG.getConstant(
- APInt::getNullValue(
- N0.getValueType().getScalarType().getSizeInBits()),
- SDLoc(N), N0.getValueType());
+ return DAG.getConstant(APInt::getNullValue(N0.getScalarValueSizeInBits()),
+ SDLoc(N), N0.getValueType());
if (ISD::isBuildVectorAllZeros(N1.getNode()))
// do not return N1, because undef node may exist in N1
- return DAG.getConstant(
- APInt::getNullValue(
- N1.getValueType().getScalarType().getSizeInBits()),
- SDLoc(N), N1.getValueType());
+ return DAG.getConstant(APInt::getNullValue(N1.getScalarValueSizeInBits()),
+ SDLoc(N), N1.getValueType());
// fold (and x, -1) -> x, vector edition
if (ISD::isBuildVectorAllOnes(N0.getNode()))
@@ -3041,34 +4555,46 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
// fold (and c1, c2) -> c1&c2
ConstantSDNode *N0C = getAsNonOpaqueConstant(N0);
- ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
+ ConstantSDNode *N1C = isConstOrConstSplat(N1);
if (N0C && N1C && !N1C->isOpaque())
return DAG.FoldConstantArithmetic(ISD::AND, SDLoc(N), VT, N0C, N1C);
// canonicalize constant to RHS
- if (isConstantIntBuildVectorOrConstantInt(N0) &&
- !isConstantIntBuildVectorOrConstantInt(N1))
+ if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
+ !DAG.isConstantIntBuildVectorOrConstantInt(N1))
return DAG.getNode(ISD::AND, SDLoc(N), VT, N1, N0);
// fold (and x, -1) -> x
if (isAllOnesConstant(N1))
return N0;
// if (and x, c) is known to be zero, return 0
- unsigned BitWidth = VT.getScalarType().getSizeInBits();
+ unsigned BitWidth = VT.getScalarSizeInBits();
if (N1C && DAG.MaskedValueIsZero(SDValue(N, 0),
APInt::getAllOnesValue(BitWidth)))
return DAG.getConstant(0, SDLoc(N), VT);
+
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
// reassociate and
- if (SDValue RAND = ReassociateOps(ISD::AND, SDLoc(N), N0, N1))
+ if (SDValue RAND = ReassociateOps(ISD::AND, SDLoc(N), N0, N1, N->getFlags()))
return RAND;
+
+ // Try to convert a constant mask AND into a shuffle clear mask.
+ if (VT.isVector())
+ if (SDValue Shuffle = XformToShuffleWithZero(N))
+ return Shuffle;
+
// fold (and (or x, C), D) -> D if (C & D) == D
- if (N1C && N0.getOpcode() == ISD::OR)
- if (ConstantSDNode *ORI = dyn_cast<ConstantSDNode>(N0.getOperand(1)))
- if ((ORI->getAPIntValue() & N1C->getAPIntValue()) == N1C->getAPIntValue())
- return N1;
+ auto MatchSubset = [](ConstantSDNode *LHS, ConstantSDNode *RHS) {
+ return RHS->getAPIntValue().isSubsetOf(LHS->getAPIntValue());
+ };
+ if (N0.getOpcode() == ISD::OR &&
+ ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchSubset))
+ return N1;
// fold (and (any_ext V), c) -> (zero_ext V) if 'and' only clears top bits.
if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
SDValue N0Op0 = N0.getOperand(0);
APInt Mask = ~N1C->getAPIntValue();
- Mask = Mask.trunc(N0Op0.getValueSizeInBits());
+ Mask = Mask.trunc(N0Op0.getScalarValueSizeInBits());
if (DAG.MaskedValueIsZero(N0Op0, Mask)) {
SDValue Zext = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N),
N0.getValueType(), N0Op0);
@@ -3090,8 +4616,10 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
// the 'X' node here can either be nothing or an extract_vector_elt to catch
// more cases.
if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
- N0.getOperand(0).getOpcode() == ISD::LOAD) ||
- N0.getOpcode() == ISD::LOAD) {
+ N0.getValueSizeInBits() == N0.getOperand(0).getScalarValueSizeInBits() &&
+ N0.getOperand(0).getOpcode() == ISD::LOAD &&
+ N0.getOperand(0).getResNo() == 0) ||
+ (N0.getOpcode() == ISD::LOAD && N0.getResNo() == 0)) {
LoadSDNode *Load = cast<LoadSDNode>( (N0.getOpcode() == ISD::LOAD) ?
N0 : N0.getOperand(0) );
@@ -3117,7 +4645,7 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
// that will apply equally to all members of the vector, so AND all the
// lanes of the constant together.
EVT VT = Vector->getValueType(0);
- unsigned BitWidth = VT.getVectorElementType().getSizeInBits();
+ unsigned BitWidth = VT.getScalarSizeInBits();
// If the splat value has been compressed to a bitlength lower
// than the size of the vector lane, we need to re-expand it to
@@ -3148,8 +4676,7 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
// Resize the constant to the same size as the original memory access before
// extension. If it is still the AllOnesValue then this AND is completely
// unneeded.
- Constant =
- Constant.zextOrTrunc(Load->getMemoryVT().getScalarType().getSizeInBits());
+ Constant = Constant.zextOrTrunc(Load->getMemoryVT().getScalarSizeInBits());
bool B;
switch (Load->getExtensionType()) {
@@ -3163,6 +4690,10 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
// If the load type was an EXTLOAD, convert to ZEXTLOAD in order to
// preserve semantics once we get rid of the AND.
SDValue NewLoad(Load, 0);
+
+ // Fold the AND away. NewLoad may get replaced immediately.
+ CombineTo(N, (N0.getNode() == Load) ? NewLoad : N0);
+
if (Load->getExtensionType() == ISD::EXTLOAD) {
NewLoad = DAG.getLoad(Load->getAddressingMode(), ISD::ZEXTLOAD,
Load->getValueType(0), SDLoc(Load),
@@ -3180,10 +4711,6 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
}
}
- // Fold the AND away, taking care not to fold to the old load node if we
- // replaced it.
- CombineTo(N, (N0.getNode() == Load) ? NewLoad : N0);
-
return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
}
@@ -3191,60 +4718,25 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
// fold (and (load x), 255) -> (zextload x, i8)
// fold (and (extload x, i16), 255) -> (zextload x, i8)
// fold (and (any_ext (extload x, i16)), 255) -> (zextload x, i8)
- if (N1C && (N0.getOpcode() == ISD::LOAD ||
- (N0.getOpcode() == ISD::ANY_EXTEND &&
- N0.getOperand(0).getOpcode() == ISD::LOAD))) {
- bool HasAnyExt = N0.getOpcode() == ISD::ANY_EXTEND;
- LoadSDNode *LN0 = HasAnyExt
- ? cast<LoadSDNode>(N0.getOperand(0))
- : cast<LoadSDNode>(N0);
- if (LN0->getExtensionType() != ISD::SEXTLOAD &&
- LN0->isUnindexed() && N0.hasOneUse() && SDValue(LN0, 0).hasOneUse()) {
- auto NarrowLoad = false;
- EVT LoadResultTy = HasAnyExt ? LN0->getValueType(0) : VT;
- EVT ExtVT, LoadedVT;
- if (isAndLoadExtLoad(N1C, LN0, LoadResultTy, ExtVT, LoadedVT,
- NarrowLoad)) {
- if (!NarrowLoad) {
- SDValue NewLoad =
- DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(LN0), LoadResultTy,
- LN0->getChain(), LN0->getBasePtr(), ExtVT,
- LN0->getMemOperand());
- AddToWorklist(N);
- CombineTo(LN0, NewLoad, NewLoad.getValue(1));
- return SDValue(N, 0); // Return N so it doesn't get rechecked!
- } else {
- EVT PtrType = LN0->getOperand(1).getValueType();
-
- unsigned Alignment = LN0->getAlignment();
- SDValue NewPtr = LN0->getBasePtr();
-
- // For big endian targets, we need to add an offset to the pointer
- // to load the correct bytes. For little endian systems, we merely
- // need to read fewer bytes from the same pointer.
- if (DAG.getDataLayout().isBigEndian()) {
- unsigned LVTStoreBytes = LoadedVT.getStoreSize();
- unsigned EVTStoreBytes = ExtVT.getStoreSize();
- unsigned PtrOff = LVTStoreBytes - EVTStoreBytes;
- SDLoc DL(LN0);
- NewPtr = DAG.getNode(ISD::ADD, DL, PtrType,
- NewPtr, DAG.getConstant(PtrOff, DL, PtrType));
- Alignment = MinAlign(Alignment, PtrOff);
- }
+ if (!VT.isVector() && N1C && (N0.getOpcode() == ISD::LOAD ||
+ (N0.getOpcode() == ISD::ANY_EXTEND &&
+ N0.getOperand(0).getOpcode() == ISD::LOAD))) {
+ if (SDValue Res = ReduceLoadWidth(N)) {
+ LoadSDNode *LN0 = N0->getOpcode() == ISD::ANY_EXTEND
+ ? cast<LoadSDNode>(N0.getOperand(0)) : cast<LoadSDNode>(N0);
+ AddToWorklist(N);
+ DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 0), Res);
+ return SDValue(N, 0);
+ }
+ }
- AddToWorklist(NewPtr.getNode());
-
- SDValue Load =
- DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(LN0), LoadResultTy,
- LN0->getChain(), NewPtr,
- LN0->getPointerInfo(),
- ExtVT, LN0->isVolatile(), LN0->isNonTemporal(),
- LN0->isInvariant(), Alignment, LN0->getAAInfo());
- AddToWorklist(N);
- CombineTo(LN0, Load, Load.getValue(1));
- return SDValue(N, 0); // Return N so it doesn't get rechecked!
- }
- }
+ if (Level >= AfterLegalizeTypes) {
+ // Attempt to propagate the AND back up to the leaves which, if they're
+ // loads, can be combined to narrow loads and the AND node can be removed.
+ // Perform after legalization so that extend nodes will already be
+ // combined into the loads.
+ if (BackwardsPropagateMask(N, DAG)) {
+ return SDValue(N, 0);
}
}
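What the narrowed load buys at the source level, as a standalone C++ sketch (illustrative, not part of the patch; little-endian byte layout assumed):

  #include <cstdint>
  // Masking a wide load with 255 only needs the low-addressed byte, so the
  // combiner can shrink it to one zero-extending i8 load via ReduceLoadWidth.
  uint32_t low_byte(const uint32_t *p) {
    return *p & 0xFFu; // becomes: zextload i8 from (const char *)p
  }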
@@ -3253,13 +4745,31 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
// Simplify: (and (op x...), (op y...)) -> (op (and x, y))
if (N0.getOpcode() == N1.getOpcode())
- if (SDValue Tmp = SimplifyBinOpWithSameOpcodeHands(N))
- return Tmp;
+ if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
+ return V;
+
+ // Masking the negated extension of a boolean is just the zero-extended
+ // boolean:
+ // and (sub 0, zext(bool X)), 1 --> zext(bool X)
+ // and (sub 0, sext(bool X)), 1 --> zext(bool X)
+ //
+ // Note: the SimplifyDemandedBits fold below can make an information-losing
+ // transform, and then we have no way to find this better fold.
+ if (N1C && N1C->isOne() && N0.getOpcode() == ISD::SUB) {
+ if (isNullOrNullSplat(N0.getOperand(0))) {
+ SDValue SubRHS = N0.getOperand(1);
+ if (SubRHS.getOpcode() == ISD::ZERO_EXTEND &&
+ SubRHS.getOperand(0).getScalarValueSizeInBits() == 1)
+ return SubRHS;
+ if (SubRHS.getOpcode() == ISD::SIGN_EXTEND &&
+ SubRHS.getOperand(0).getScalarValueSizeInBits() == 1)
+ return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, SubRHS.getOperand(0));
+ }
+ }
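A standalone sketch (not from the patch) verifying the identity behind this fold for both boolean values:

  #include <cassert>
  #include <cstdint>
  int main() {
    for (uint32_t b = 0; b <= 1; ++b) {
      assert(((0u - b) & 1u) == b);  // and (sub 0, zext(bool)), 1 --> zext(bool)
      uint32_t s = b ? ~0u : 0u;     // sext(bool) is 0 or -1
      assert(((0u - s) & 1u) == b);  // and (sub 0, sext(bool)), 1 --> zext(bool)
    }
  }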
// fold (and (sign_extend_inreg x, i16 to i32), 1) -> (and x, 1)
// fold (and (sra)) -> (and (srl)) when possible.
- if (!VT.isVector() &&
- SimplifyDemandedBits(SDValue(N, 0)))
+ if (SimplifyDemandedBits(SDValue(N, 0)))
return SDValue(N, 0);
// fold (zext_inreg (extload x)) -> (zextload x)
@@ -3268,9 +4778,9 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
EVT MemVT = LN0->getMemoryVT();
// If we zero all the possible extended bits, then we can turn this into
// a zextload if we are running before legalize or the operation is legal.
- unsigned BitWidth = N1.getValueType().getScalarType().getSizeInBits();
+ unsigned BitWidth = N1.getScalarValueSizeInBits();
if (DAG.MaskedValueIsZero(N1, APInt::getHighBitsSet(BitWidth,
- BitWidth - MemVT.getScalarType().getSizeInBits())) &&
+ BitWidth - MemVT.getScalarSizeInBits())) &&
((!LegalOperations && !LN0->isVolatile()) ||
TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT))) {
SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N0), VT,
@@ -3288,9 +4798,9 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
EVT MemVT = LN0->getMemoryVT();
// If we zero all the possible extended bits, then we can turn this into
// a zextload if we are running before legalize or the operation is legal.
- unsigned BitWidth = N1.getValueType().getScalarType().getSizeInBits();
+ unsigned BitWidth = N1.getScalarValueSizeInBits();
if (DAG.MaskedValueIsZero(N1, APInt::getHighBitsSet(BitWidth,
- BitWidth - MemVT.getScalarType().getSizeInBits())) &&
+ BitWidth - MemVT.getScalarSizeInBits())) &&
((!LegalOperations && !LN0->isVolatile()) ||
TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT))) {
SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N0), VT,
@@ -3303,12 +4813,14 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
}
// fold (and (or (srl N, 8), (shl N, 8)), 0xffff) -> (srl (bswap N), const)
if (N1C && N1C->getAPIntValue() == 0xffff && N0.getOpcode() == ISD::OR) {
- SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
- N0.getOperand(1), false);
- if (BSwap.getNode())
+ if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
+ N0.getOperand(1), false))
return BSwap;
}
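The core identity behind this match, with the inner byte masks written out explicitly; a standalone sketch assuming a compiler that provides __builtin_bswap32 (GCC/Clang), with an arbitrary test value:

  #include <cassert>
  #include <cstdint>
  int main() {
    uint32_t n = 0x11223344u;
    uint32_t lhs = ((n >> 8) & 0x00FFu) | ((n << 8) & 0xFF00u);
    uint32_t rhs = __builtin_bswap32(n) >> 16; // (srl (bswap n), 16)
    assert(lhs == rhs && lhs == 0x4433u);      // low halfword, byte-swapped
  }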
+ if (SDValue Shifts = unfoldExtremeBitClearingToShifts(N))
+ return Shifts;
+
return SDValue();
}
@@ -3321,10 +4833,10 @@ SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
EVT VT = N->getValueType(0);
if (VT != MVT::i64 && VT != MVT::i32 && VT != MVT::i16)
return SDValue();
- if (!TLI.isOperationLegal(ISD::BSWAP, VT))
+ if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
return SDValue();
- // Recognize (and (shl a, 8), 0xff), (and (srl a, 8), 0xff00)
+ // Recognize (and (shl a, 8), 0xff00), (and (srl a, 8), 0xff)
bool LookPassAnd0 = false;
bool LookPassAnd1 = false;
if (N0.getOpcode() == ISD::AND && N0.getOperand(0).getOpcode() == ISD::SRL)
@@ -3335,7 +4847,10 @@ SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
if (!N0.getNode()->hasOneUse())
return SDValue();
ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
- if (!N01C || N01C->getZExtValue() != 0xFF00)
+ // Also handle 0xffff since the LHS is guaranteed to have zeros there.
+ // This is needed for X86.
+ if (!N01C || (N01C->getZExtValue() != 0xFF00 &&
+ N01C->getZExtValue() != 0xFFFF))
return SDValue();
N0 = N0.getOperand(0);
LookPassAnd0 = true;
@@ -3355,8 +4870,7 @@ SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
std::swap(N0, N1);
if (N0.getOpcode() != ISD::SHL || N1.getOpcode() != ISD::SRL)
return SDValue();
- if (!N0.getNode()->hasOneUse() ||
- !N1.getNode()->hasOneUse())
+ if (!N0.getNode()->hasOneUse() || !N1.getNode()->hasOneUse())
return SDValue();
ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
@@ -3383,7 +4897,10 @@ SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
if (!N10.getNode()->hasOneUse())
return SDValue();
ConstantSDNode *N101C = dyn_cast<ConstantSDNode>(N10.getOperand(1));
- if (!N101C || N101C->getZExtValue() != 0xFF00)
+ // Also allow 0xFFFF since the bits will be shifted out. This is needed
+ // for X86.
+ if (!N101C || (N101C->getZExtValue() != 0xFF00 &&
+ N101C->getZExtValue() != 0xFFFF))
return SDValue();
N10 = N10.getOperand(0);
LookPassAnd1 = true;
@@ -3434,27 +4951,44 @@ static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) {
if (Opc != ISD::AND && Opc != ISD::SHL && Opc != ISD::SRL)
return false;
- ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N.getOperand(1));
+ SDValue N0 = N.getOperand(0);
+ unsigned Opc0 = N0.getOpcode();
+ if (Opc0 != ISD::AND && Opc0 != ISD::SHL && Opc0 != ISD::SRL)
+ return false;
+
+ ConstantSDNode *N1C = nullptr;
+ // SHL or SRL: look upstream for AND mask operand
+ if (Opc == ISD::AND)
+ N1C = dyn_cast<ConstantSDNode>(N.getOperand(1));
+ else if (Opc0 == ISD::AND)
+ N1C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
if (!N1C)
return false;
- unsigned Num;
+ unsigned MaskByteOffset;
switch (N1C->getZExtValue()) {
default:
return false;
- case 0xFF: Num = 0; break;
- case 0xFF00: Num = 1; break;
- case 0xFF0000: Num = 2; break;
- case 0xFF000000: Num = 3; break;
+ case 0xFF: MaskByteOffset = 0; break;
+ case 0xFF00: MaskByteOffset = 1; break;
+ case 0xFFFF:
+ // In case demanded bits didn't clear the bits that will be shifted out.
+ // This is needed for X86.
+ if (Opc == ISD::SRL || (Opc == ISD::AND && Opc0 == ISD::SHL)) {
+ MaskByteOffset = 1;
+ break;
+ }
+ return false;
+ case 0xFF0000: MaskByteOffset = 2; break;
+ case 0xFF000000: MaskByteOffset = 3; break;
}
// Look for (x & 0xff) << 8 as well as ((x << 8) & 0xff00).
- SDValue N0 = N.getOperand(0);
if (Opc == ISD::AND) {
- if (Num == 0 || Num == 2) {
+ if (MaskByteOffset == 0 || MaskByteOffset == 2) {
// (x >> 8) & 0xff
// (x >> 8) & 0xff0000
- if (N0.getOpcode() != ISD::SRL)
+ if (Opc0 != ISD::SRL)
return false;
ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
if (!C || C->getZExtValue() != 8)
@@ -3462,7 +4996,7 @@ static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) {
} else {
// (x << 8) & 0xff00
// (x << 8) & 0xff000000
- if (N0.getOpcode() != ISD::SHL)
+ if (Opc0 != ISD::SHL)
return false;
ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
if (!C || C->getZExtValue() != 8)
@@ -3471,7 +5005,7 @@ static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) {
} else if (Opc == ISD::SHL) {
// (x & 0xff) << 8
// (x & 0xff0000) << 8
- if (Num != 0 && Num != 2)
+ if (MaskByteOffset != 0 && MaskByteOffset != 2)
return false;
ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
if (!C || C->getZExtValue() != 8)
@@ -3479,17 +5013,17 @@ static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) {
} else { // Opc == ISD::SRL
// (x & 0xff00) >> 8
// (x & 0xff000000) >> 8
- if (Num != 1 && Num != 3)
+ if (MaskByteOffset != 1 && MaskByteOffset != 3)
return false;
ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
if (!C || C->getZExtValue() != 8)
return false;
}
- if (Parts[Num])
+ if (Parts[MaskByteOffset])
return false;
- Parts[Num] = N0.getOperand(0).getNode();
+ Parts[MaskByteOffset] = N0.getOperand(0).getNode();
return true;
}
@@ -3506,7 +5040,7 @@ SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
EVT VT = N->getValueType(0);
if (VT != MVT::i32)
return SDValue();
- if (!TLI.isOperationLegal(ISD::BSWAP, VT))
+ if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
return SDValue();
// Look for either
@@ -3521,18 +5055,16 @@ SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
if (N1.getOpcode() == ISD::OR &&
N00.getNumOperands() == 2 && N01.getNumOperands() == 2) {
// (or (or (and), (and)), (or (and), (and)))
- SDValue N000 = N00.getOperand(0);
- if (!isBSwapHWordElement(N000, Parts))
+ if (!isBSwapHWordElement(N00, Parts))
return SDValue();
- SDValue N001 = N00.getOperand(1);
- if (!isBSwapHWordElement(N001, Parts))
+ if (!isBSwapHWordElement(N01, Parts))
return SDValue();
- SDValue N010 = N01.getOperand(0);
- if (!isBSwapHWordElement(N010, Parts))
+ SDValue N10 = N1.getOperand(0);
+ if (!isBSwapHWordElement(N10, Parts))
return SDValue();
- SDValue N011 = N01.getOperand(1);
- if (!isBSwapHWordElement(N011, Parts))
+ SDValue N11 = N1.getOperand(1);
+ if (!isBSwapHWordElement(N11, Parts))
return SDValue();
} else {
// (or (or (or (and), (and)), (and)), (and))
@@ -3572,59 +5104,16 @@ SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
/// This contains all DAGCombine rules which reduce two values combined by
/// an Or operation to a single value \see visitANDLike().
-SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *LocReference) {
+SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *N) {
EVT VT = N1.getValueType();
+ SDLoc DL(N);
+
// fold (or x, undef) -> -1
- if (!LegalOperations &&
- (N0.getOpcode() == ISD::UNDEF || N1.getOpcode() == ISD::UNDEF)) {
- EVT EltVT = VT.isVector() ? VT.getVectorElementType() : VT;
- return DAG.getConstant(APInt::getAllOnesValue(EltVT.getSizeInBits()),
- SDLoc(LocReference), VT);
- }
- // fold (or (setcc x), (setcc y)) -> (setcc (or x, y))
- SDValue LL, LR, RL, RR, CC0, CC1;
- if (isSetCCEquivalent(N0, LL, LR, CC0) && isSetCCEquivalent(N1, RL, RR, CC1)){
- ISD::CondCode Op0 = cast<CondCodeSDNode>(CC0)->get();
- ISD::CondCode Op1 = cast<CondCodeSDNode>(CC1)->get();
-
- if (LR == RR && Op0 == Op1 && LL.getValueType().isInteger()) {
- // fold (or (setne X, 0), (setne Y, 0)) -> (setne (or X, Y), 0)
- // fold (or (setlt X, 0), (setlt Y, 0)) -> (setne (or X, Y), 0)
- if (isNullConstant(LR) && (Op1 == ISD::SETNE || Op1 == ISD::SETLT)) {
- SDValue ORNode = DAG.getNode(ISD::OR, SDLoc(LR),
- LR.getValueType(), LL, RL);
- AddToWorklist(ORNode.getNode());
- return DAG.getSetCC(SDLoc(LocReference), VT, ORNode, LR, Op1);
- }
- // fold (or (setne X, -1), (setne Y, -1)) -> (setne (and X, Y), -1)
- // fold (or (setgt X, -1), (setgt Y -1)) -> (setgt (and X, Y), -1)
- if (isAllOnesConstant(LR) && (Op1 == ISD::SETNE || Op1 == ISD::SETGT)) {
- SDValue ANDNode = DAG.getNode(ISD::AND, SDLoc(LR),
- LR.getValueType(), LL, RL);
- AddToWorklist(ANDNode.getNode());
- return DAG.getSetCC(SDLoc(LocReference), VT, ANDNode, LR, Op1);
- }
- }
- // canonicalize equivalent to ll == rl
- if (LL == RR && LR == RL) {
- Op1 = ISD::getSetCCSwappedOperands(Op1);
- std::swap(RL, RR);
- }
- if (LL == RL && LR == RR) {
- bool isInteger = LL.getValueType().isInteger();
- ISD::CondCode Result = ISD::getSetCCOrOperation(Op0, Op1, isInteger);
- if (Result != ISD::SETCC_INVALID &&
- (!LegalOperations ||
- (TLI.isCondCodeLegal(Result, LL.getSimpleValueType()) &&
- TLI.isOperationLegal(ISD::SETCC, LL.getValueType())))) {
- EVT CCVT = getSetCCResultType(LL.getValueType());
- if (N0.getValueType() == CCVT ||
- (!LegalOperations && N0.getValueType() == MVT::i1))
- return DAG.getSetCC(SDLoc(LocReference), N0.getValueType(),
- LL, LR, Result);
- }
- }
- }
+ if (!LegalOperations && (N0.isUndef() || N1.isUndef()))
+ return DAG.getAllOnesConstant(DL, VT);
+
+ if (SDValue V = foldLogicOfSetCCs(false, N0, N1, DL))
+ return V;
// (or (and X, C1), (and Y, C2)) -> (and (or X, Y), C3) if possible.
if (N0.getOpcode() == ISD::AND && N1.getOpcode() == ISD::AND &&
@@ -3645,7 +5134,6 @@ SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *LocReference) {
DAG.MaskedValueIsZero(N1.getOperand(0), LHSMask&~RHSMask)) {
SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
N0.getOperand(0), N1.getOperand(0));
- SDLoc DL(LocReference);
return DAG.getNode(ISD::AND, DL, VT, X,
DAG.getConstant(LHSMask | RHSMask, DL, VT));
}
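A standalone sketch of why the fold is sound when the MaskedValueIsZero preconditions hold (here they hold by construction; values arbitrary):

  #include <cassert>
  #include <cstdint>
  int main() {
    uint32_t x = 0x0000005Au; // no bits in C2 & ~C1 (0xFF00)
    uint32_t y = 0x00001200u; // no bits in C1 & ~C2 (0x00FF)
    // (or (and X, C1), (and Y, C2)) == (and (or X, Y), C1|C2)
    assert(((x & 0x00FFu) | (y & 0xFF00u)) == ((x | y) & 0xFFFFu));
  }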
@@ -3661,7 +5149,7 @@ SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *LocReference) {
(N0.getNode()->hasOneUse() || N1.getNode()->hasOneUse())) {
SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
N0.getOperand(1), N1.getOperand(1));
- return DAG.getNode(ISD::AND, SDLoc(LocReference), VT, N0.getOperand(0), X);
+ return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), X);
}
return SDValue();
@@ -3672,6 +5160,10 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
SDValue N1 = N->getOperand(1);
EVT VT = N1.getValueType();
+ // x | x --> x
+ if (N0 == N1)
+ return N0;
+
// fold vector ops
if (VT.isVector()) {
if (SDValue FoldedVOp = SimplifyVBinOp(N))
@@ -3686,70 +5178,75 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
// fold (or x, -1) -> -1, vector edition
if (ISD::isBuildVectorAllOnes(N0.getNode()))
// do not return N0, because undef node may exist in N0
- return DAG.getConstant(
- APInt::getAllOnesValue(
- N0.getValueType().getScalarType().getSizeInBits()),
- SDLoc(N), N0.getValueType());
+ return DAG.getAllOnesConstant(SDLoc(N), N0.getValueType());
if (ISD::isBuildVectorAllOnes(N1.getNode()))
// do not return N1, because undef node may exist in N1
- return DAG.getConstant(
- APInt::getAllOnesValue(
- N1.getValueType().getScalarType().getSizeInBits()),
- SDLoc(N), N1.getValueType());
+ return DAG.getAllOnesConstant(SDLoc(N), N1.getValueType());
- // fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf A, B, Mask1)
- // fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf B, A, Mask2)
+ // fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf A, B, Mask)
// Do this only if the resulting shuffle is legal.
if (isa<ShuffleVectorSDNode>(N0) &&
isa<ShuffleVectorSDNode>(N1) &&
// Avoid folding a node with illegal type.
- TLI.isTypeLegal(VT) &&
- N0->getOperand(1) == N1->getOperand(1) &&
- ISD::isBuildVectorAllZeros(N0.getOperand(1).getNode())) {
- bool CanFold = true;
- unsigned NumElts = VT.getVectorNumElements();
- const ShuffleVectorSDNode *SV0 = cast<ShuffleVectorSDNode>(N0);
- const ShuffleVectorSDNode *SV1 = cast<ShuffleVectorSDNode>(N1);
- // We construct two shuffle masks:
- // - Mask1 is a shuffle mask for a shuffle with N0 as the first operand
- // and N1 as the second operand.
- // - Mask2 is a shuffle mask for a shuffle with N1 as the first operand
- // and N0 as the second operand.
- // We do this because OR is commutable and therefore there might be
- // two ways to fold this node into a shuffle.
- SmallVector<int,4> Mask1;
- SmallVector<int,4> Mask2;
-
- for (unsigned i = 0; i != NumElts && CanFold; ++i) {
- int M0 = SV0->getMaskElt(i);
- int M1 = SV1->getMaskElt(i);
-
- // Both shuffle indexes are undef. Propagate Undef.
- if (M0 < 0 && M1 < 0) {
- Mask1.push_back(M0);
- Mask2.push_back(M0);
- continue;
- }
+ TLI.isTypeLegal(VT)) {
+ bool ZeroN00 = ISD::isBuildVectorAllZeros(N0.getOperand(0).getNode());
+ bool ZeroN01 = ISD::isBuildVectorAllZeros(N0.getOperand(1).getNode());
+ bool ZeroN10 = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());
+ bool ZeroN11 = ISD::isBuildVectorAllZeros(N1.getOperand(1).getNode());
+ // Ensure both shuffles have a zero input.
+ if ((ZeroN00 != ZeroN01) && (ZeroN10 != ZeroN11)) {
+ assert((!ZeroN00 || !ZeroN01) && "Both inputs zero!");
+ assert((!ZeroN10 || !ZeroN11) && "Both inputs zero!");
+ const ShuffleVectorSDNode *SV0 = cast<ShuffleVectorSDNode>(N0);
+ const ShuffleVectorSDNode *SV1 = cast<ShuffleVectorSDNode>(N1);
+ bool CanFold = true;
+ int NumElts = VT.getVectorNumElements();
+ SmallVector<int, 4> Mask(NumElts);
+
+ for (int i = 0; i != NumElts; ++i) {
+ int M0 = SV0->getMaskElt(i);
+ int M1 = SV1->getMaskElt(i);
+
+ // Determine if either index is pointing to a zero vector.
+ bool M0Zero = M0 < 0 || (ZeroN00 == (M0 < NumElts));
+ bool M1Zero = M1 < 0 || (ZeroN10 == (M1 < NumElts));
+
+ // If one element is zero and the other side is undef, keep undef.
+ // This also handles the case that both are undef.
+ if ((M0Zero && M1 < 0) || (M1Zero && M0 < 0)) {
+ Mask[i] = -1;
+ continue;
+ }
- if (M0 < 0 || M1 < 0 ||
- (M0 < (int)NumElts && M1 < (int)NumElts) ||
- (M0 >= (int)NumElts && M1 >= (int)NumElts)) {
- CanFold = false;
- break;
+ // Make sure only one of the elements is zero.
+ if (M0Zero == M1Zero) {
+ CanFold = false;
+ break;
+ }
+
+ assert((M0 >= 0 || M1 >= 0) && "Undef index!");
+
+ // We have a zero and non-zero element. If the non-zero came from
+ // SV0 make the index a LHS index. If it came from SV1, make it
+ // a RHS index. We need to mod by NumElts because we don't care
+ // which operand it came from in the original shuffles.
+ Mask[i] = M1Zero ? M0 % NumElts : (M1 % NumElts) + NumElts;
}
- Mask1.push_back(M0 < (int)NumElts ? M0 : M1 + NumElts);
- Mask2.push_back(M1 < (int)NumElts ? M1 : M0 + NumElts);
- }
+ if (CanFold) {
+ SDValue NewLHS = ZeroN00 ? N0.getOperand(1) : N0.getOperand(0);
+ SDValue NewRHS = ZeroN10 ? N1.getOperand(1) : N1.getOperand(0);
- if (CanFold) {
- // Fold this sequence only if the resulting shuffle is 'legal'.
- if (TLI.isShuffleMaskLegal(Mask1, VT))
- return DAG.getVectorShuffle(VT, SDLoc(N), N0->getOperand(0),
- N1->getOperand(0), &Mask1[0]);
- if (TLI.isShuffleMaskLegal(Mask2, VT))
- return DAG.getVectorShuffle(VT, SDLoc(N), N1->getOperand(0),
- N0->getOperand(0), &Mask2[0]);
+ bool LegalMask = TLI.isShuffleMaskLegal(Mask, VT);
+ if (!LegalMask) {
+ std::swap(NewLHS, NewRHS);
+ ShuffleVectorSDNode::commuteMask(Mask);
+ LegalMask = TLI.isShuffleMaskLegal(Mask, VT);
+ }
+
+ if (LegalMask)
+ return DAG.getVectorShuffle(VT, SDLoc(N), NewLHS, NewRHS, Mask);
+ }
}
}
}
@@ -3760,8 +5257,8 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
if (N0C && N1C && !N1C->isOpaque())
return DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N), VT, N0C, N1C);
// canonicalize constant to RHS
- if (isConstantIntBuildVectorOrConstantInt(N0) &&
- !isConstantIntBuildVectorOrConstantInt(N1))
+ if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
+ !DAG.isConstantIntBuildVectorOrConstantInt(N1))
return DAG.getNode(ISD::OR, SDLoc(N), VT, N1, N0);
// fold (or x, 0) -> x
if (isNullConstant(N1))
@@ -3769,6 +5266,10 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
// fold (or x, -1) -> -1
if (isAllOnesConstant(N1))
return N1;
+
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
// fold (or x, c) -> c iff (x & ~c) == 0
if (N1C && DAG.MaskedValueIsZero(N0, ~N1C->getAPIntValue()))
return N1;
@@ -3783,58 +5284,177 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
return BSwap;
// reassociate or
- if (SDValue ROR = ReassociateOps(ISD::OR, SDLoc(N), N0, N1))
+ if (SDValue ROR = ReassociateOps(ISD::OR, SDLoc(N), N0, N1, N->getFlags()))
return ROR;
+
// Canonicalize (or (and X, c1), c2) -> (and (or X, c2), c1|c2)
- // iff (c1 & c2) == 0.
- if (N1C && N0.getOpcode() == ISD::AND && N0.getNode()->hasOneUse() &&
- isa<ConstantSDNode>(N0.getOperand(1))) {
- ConstantSDNode *C1 = cast<ConstantSDNode>(N0.getOperand(1));
- if ((C1->getAPIntValue() & N1C->getAPIntValue()) != 0) {
- if (SDValue COR = DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N1), VT,
- N1C, C1))
- return DAG.getNode(
- ISD::AND, SDLoc(N), VT,
- DAG.getNode(ISD::OR, SDLoc(N0), VT, N0.getOperand(0), N1), COR);
- return SDValue();
+ // iff (c1 & c2) != 0 or c1/c2 are undef.
+ auto MatchIntersect = [](ConstantSDNode *C1, ConstantSDNode *C2) {
+ return !C1 || !C2 || C1->getAPIntValue().intersects(C2->getAPIntValue());
+ };
+ if (N0.getOpcode() == ISD::AND && N0.getNode()->hasOneUse() &&
+ ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchIntersect, true)) {
+ if (SDValue COR = DAG.FoldConstantArithmetic(
+ ISD::OR, SDLoc(N1), VT, N1.getNode(), N0.getOperand(1).getNode())) {
+ SDValue IOR = DAG.getNode(ISD::OR, SDLoc(N0), VT, N0.getOperand(0), N1);
+ AddToWorklist(IOR.getNode());
+ return DAG.getNode(ISD::AND, SDLoc(N), VT, COR, IOR);
}
}
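The rewritten canonicalization rests on an unconditional bit identity; an exhaustive 8-bit standalone check, with arbitrary masks chosen to intersect as the new predicate requires:

  #include <cassert>
  int main() {
    const unsigned c1 = 0x3C, c2 = 0x0F; // c1 & c2 != 0
    for (unsigned x = 0; x < 256; ++x)
      assert(((x & c1) | c2) == ((x | c2) & (c1 | c2)));
  }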
+
// Simplify: (or (op x...), (op y...)) -> (op (or x, y))
if (N0.getOpcode() == N1.getOpcode())
- if (SDValue Tmp = SimplifyBinOpWithSameOpcodeHands(N))
- return Tmp;
+ if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
+ return V;
// See if this is some rotate idiom.
if (SDNode *Rot = MatchRotate(N0, N1, SDLoc(N)))
return SDValue(Rot, 0);
+ if (SDValue Load = MatchLoadCombine(N))
+ return Load;
+
// Simplify the operands using demanded-bits information.
- if (!VT.isVector() &&
- SimplifyDemandedBits(SDValue(N, 0)))
+ if (SimplifyDemandedBits(SDValue(N, 0)))
return SDValue(N, 0);
return SDValue();
}
-/// Match "(X shl/srl V1) & V2" where V2 may not be present.
-static bool MatchRotateHalf(SDValue Op, SDValue &Shift, SDValue &Mask) {
- if (Op.getOpcode() == ISD::AND) {
- if (isConstantIntBuildVectorOrConstantInt(Op.getOperand(1))) {
- Mask = Op.getOperand(1);
- Op = Op.getOperand(0);
- } else {
- return false;
- }
+static SDValue stripConstantMask(SelectionDAG &DAG, SDValue Op, SDValue &Mask) {
+ if (Op.getOpcode() == ISD::AND &&
+ DAG.isConstantIntBuildVectorOrConstantInt(Op.getOperand(1))) {
+ Mask = Op.getOperand(1);
+ return Op.getOperand(0);
}
+ return Op;
+}
+/// Match "(X shl/srl V1) & V2" where V2 may not be present.
+static bool matchRotateHalf(SelectionDAG &DAG, SDValue Op, SDValue &Shift,
+ SDValue &Mask) {
+ Op = stripConstantMask(DAG, Op, Mask);
if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SHL) {
Shift = Op;
return true;
}
-
return false;
}
+/// Helper function for visitOR to extract the needed side of a rotate idiom
+/// from a shl/srl/mul/udiv. This is meant to handle cases where
+/// InstCombine merged some outside op with one of the shifts from
+/// the rotate pattern.
+/// \returns An empty \c SDValue if the needed shift couldn't be extracted.
+/// Otherwise, returns an expansion of \p ExtractFrom based on the following
+/// patterns:
+///
+/// (or (mul v c0) (shrl (mul v c1) c2)):
+/// expands (mul v c0) -> (shl (mul v c1) c3)
+///
+/// (or (udiv v c0) (shl (udiv v c1) c2)):
+/// expands (udiv v c0) -> (shrl (udiv v c1) c3)
+///
+/// (or (shl v c0) (shrl (shl v c1) c2)):
+/// expands (shl v c0) -> (shl (shl v c1) c3)
+///
+/// (or (shrl v c0) (shl (shrl v c1) c2)):
+/// expands (shrl v c0) -> (shrl (shrl v c1) c3)
+///
+/// Such that in all cases, c3+c2==bitwidth(op v c1).
+static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift,
+ SDValue ExtractFrom, SDValue &Mask,
+ const SDLoc &DL) {
+ assert(OppShift && ExtractFrom && "Empty SDValue");
+ assert(
+ (OppShift.getOpcode() == ISD::SHL || OppShift.getOpcode() == ISD::SRL) &&
+ "Existing shift must be valid as a rotate half");
+
+ ExtractFrom = stripConstantMask(DAG, ExtractFrom, Mask);
+ // Preconditions:
+ // (or (op0 v c0) (shiftl/r (op0 v c1) c2))
+ //
+ // Find opcode of the needed shift to be extracted from (op0 v c0).
+ unsigned Opcode = ISD::DELETED_NODE;
+ bool IsMulOrDiv = false;
+ // Set Opcode and IsMulOrDiv if the extract opcode matches the needed shift
+ // opcode or its arithmetic (mul or udiv) variant.
+ auto SelectOpcode = [&](unsigned NeededShift, unsigned MulOrDivVariant) {
+ IsMulOrDiv = ExtractFrom.getOpcode() == MulOrDivVariant;
+ if (!IsMulOrDiv && ExtractFrom.getOpcode() != NeededShift)
+ return false;
+ Opcode = NeededShift;
+ return true;
+ };
+ // op0 must be either the needed shift opcode or the mul/udiv equivalent
+ // that the needed shift can be extracted from.
+ if ((OppShift.getOpcode() != ISD::SRL || !SelectOpcode(ISD::SHL, ISD::MUL)) &&
+ (OppShift.getOpcode() != ISD::SHL || !SelectOpcode(ISD::SRL, ISD::UDIV)))
+ return SDValue();
+
+ // op0 must be the same opcode on both sides, have the same LHS argument,
+ // and produce the same value type.
+ SDValue OppShiftLHS = OppShift.getOperand(0);
+ EVT ShiftedVT = OppShiftLHS.getValueType();
+ if (OppShiftLHS.getOpcode() != ExtractFrom.getOpcode() ||
+ OppShiftLHS.getOperand(0) != ExtractFrom.getOperand(0) ||
+ ShiftedVT != ExtractFrom.getValueType())
+ return SDValue();
+
+ // Amount of the existing shift.
+ ConstantSDNode *OppShiftCst = isConstOrConstSplat(OppShift.getOperand(1));
+ // Constant mul/udiv/shift amount from the RHS of the shift's LHS op.
+ ConstantSDNode *OppLHSCst = isConstOrConstSplat(OppShiftLHS.getOperand(1));
+ // Constant mul/udiv/shift amount from the RHS of the ExtractFrom op.
+ ConstantSDNode *ExtractFromCst =
+ isConstOrConstSplat(ExtractFrom.getOperand(1));
+ // TODO: We should be able to handle non-uniform constant vectors for these
+ // values.
+ // Check that we have constant values.
+ if (!OppShiftCst || !OppShiftCst->getAPIntValue() ||
+ !OppLHSCst || !OppLHSCst->getAPIntValue() ||
+ !ExtractFromCst || !ExtractFromCst->getAPIntValue())
+ return SDValue();
+
+ // Compute the shift amount we need to extract to complete the rotate.
+ const unsigned VTWidth = ShiftedVT.getScalarSizeInBits();
+ if (OppShiftCst->getAPIntValue().ugt(VTWidth))
+ return SDValue();
+ APInt NeededShiftAmt = VTWidth - OppShiftCst->getAPIntValue();
+ // Normalize the bitwidth of the two mul/udiv/shift constant operands.
+ APInt ExtractFromAmt = ExtractFromCst->getAPIntValue();
+ APInt OppLHSAmt = OppLHSCst->getAPIntValue();
+ zeroExtendToMatch(ExtractFromAmt, OppLHSAmt);
+
+ // Now try to extract the needed shift from the ExtractFrom op and see if the
+ // result matches up with the existing shift's LHS op.
+ if (IsMulOrDiv) {
+ // Op to extract from is a mul or udiv by a constant.
+ // Check:
+ // c2 / (1 << (bitwidth(op0 v c0) - c1)) == c0
+ // c2 % (1 << (bitwidth(op0 v c0) - c1)) == 0
+ const APInt ExtractDiv = APInt::getOneBitSet(ExtractFromAmt.getBitWidth(),
+ NeededShiftAmt.getZExtValue());
+ APInt ResultAmt;
+ APInt Rem;
+ APInt::udivrem(ExtractFromAmt, ExtractDiv, ResultAmt, Rem);
+ if (Rem != 0 || ResultAmt != OppLHSAmt)
+ return SDValue();
+ } else {
+ // Op to extract from is a shift by a constant.
+ // Check:
+ // c2 - (bitwidth(op0 v c0) - c1) == c0
+ if (OppLHSAmt != ExtractFromAmt - NeededShiftAmt.zextOrTrunc(
+ ExtractFromAmt.getBitWidth()))
+ return SDValue();
+ }
+
+ // Return the expanded shift op that should allow a rotate to be formed.
+ EVT ShiftVT = OppShift.getOperand(1).getValueType();
+ EVT ResVT = ExtractFrom.getValueType();
+ SDValue NewShiftNode = DAG.getConstant(NeededShiftAmt, DL, ShiftVT);
+ return DAG.getNode(Opcode, DL, ResVT, OppShiftLHS, NewShiftNode);
+}
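A standalone sketch of the mul case: for i32 with c1 = 3 and c3 = 2 (so c0 = 12) and c2 = 30, c3 + c2 == 32 and the needed shift really can be pulled out of the multiply (test value arbitrary):

  #include <cassert>
  #include <cstdint>
  int main() {
    uint32_t v = 0xDEADBEEFu, m = v * 3u;
    assert(v * 12u == (m << 2)); // expands (mul v 12) -> (shl (mul v 3) 2)
    assert(((v * 12u) | (m >> 30)) == ((m << 2) | (m >> 30))); // rotate halves line up
  }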
+
// Return true if we can prove that, whenever Neg and Pos are both in the
// range [0, EltSize), Neg == (Pos == 0 ? 0 : EltSize - Pos). This means that
// for two opposing shifts shift1 and shift2 and a value X with OpBits bits:
@@ -3844,7 +5464,8 @@ static bool MatchRotateHalf(SDValue Op, SDValue &Shift, SDValue &Mask) {
// reduces to a rotate in direction shift2 by Pos or (equivalently) a rotate
// in direction shift1 by Neg. The range [0, EltSize) means that we only need
// to consider shift amounts with defined behavior.
-static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize) {
+static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
+ SelectionDAG &DAG) {
// If EltSize is a power of 2 then:
//
// (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1)
@@ -3879,9 +5500,12 @@ static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize) {
unsigned MaskLoBits = 0;
if (Neg.getOpcode() == ISD::AND && isPowerOf2_64(EltSize)) {
if (ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(1))) {
- if (NegC->getAPIntValue() == EltSize - 1) {
+ KnownBits Known = DAG.computeKnownBits(Neg.getOperand(0));
+ unsigned Bits = Log2_64(EltSize);
+ if (NegC->getAPIntValue().getActiveBits() <= Bits &&
+ ((NegC->getAPIntValue() | Known.Zero).countTrailingOnes() >= Bits)) {
Neg = Neg.getOperand(0);
- MaskLoBits = Log2_64(EltSize);
+ MaskLoBits = Bits;
}
}
}
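For the power-of-two case, the relationship matchRotateSub proves can be written branch-free; a standalone sketch for EltSize == 32 (function name is illustrative):

  #include <cstdint>
  uint32_t rot_pair(uint32_t x, unsigned pos) { // pos in [0, 32)
    unsigned neg = (32u - pos) & 31u; // Neg == (EltSize - Pos) & (EltSize - 1)
    return (x << pos) | (x >> neg);   // == rotl(x, pos), including pos == 0
  }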
@@ -3896,10 +5520,15 @@ static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize) {
// On the RHS of [A], if Pos is Pos' & (EltSize - 1), just replace Pos with
// Pos'. The truncation is redundant for the purpose of the equality.
- if (MaskLoBits && Pos.getOpcode() == ISD::AND)
- if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1)))
- if (PosC->getAPIntValue() == EltSize - 1)
+ if (MaskLoBits && Pos.getOpcode() == ISD::AND) {
+ if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1))) {
+ KnownBits Known = DAG.computeKnownBits(Pos.getOperand(0));
+ if (PosC->getAPIntValue().getActiveBits() <= MaskLoBits &&
+ ((PosC->getAPIntValue() | Known.Zero).countTrailingOnes() >=
+ MaskLoBits))
Pos = Pos.getOperand(0);
+ }
+ }
// The condition we need is now:
//
@@ -3946,7 +5575,7 @@ static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize) {
SDNode *DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
SDValue Neg, SDValue InnerPos,
SDValue InnerNeg, unsigned PosOpcode,
- unsigned NegOpcode, SDLoc DL) {
+ unsigned NegOpcode, const SDLoc &DL) {
// fold (or (shl x, (*ext y)),
// (srl x, (*ext (sub 32, y)))) ->
// (rotl x, y) or (rotr x, (sub 32, y))
@@ -3955,7 +5584,7 @@ SDNode *DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
// (srl x, (*ext y))) ->
// (rotr x, y) or (rotl x, (sub 32, y))
EVT VT = Shifted.getValueType();
- if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits())) {
+ if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits(), DAG)) {
bool HasPos = TLI.isOperationLegalOrCustom(PosOpcode, VT);
return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted,
HasPos ? Pos : Neg).getNode();
@@ -3967,26 +5596,63 @@ SDNode *DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
// MatchRotate - Handle an 'or' of two operands. If this is one of the many
// idioms for rotate, and if the target supports rotation instructions, generate
// a rot[lr].
-SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, SDLoc DL) {
+SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
// Must be a legal type. Expanded 'n promoted things won't work with rotates.
EVT VT = LHS.getValueType();
if (!TLI.isTypeLegal(VT)) return nullptr;
// The target must have at least one rotate flavor.
- bool HasROTL = TLI.isOperationLegalOrCustom(ISD::ROTL, VT);
- bool HasROTR = TLI.isOperationLegalOrCustom(ISD::ROTR, VT);
+ bool HasROTL = hasOperation(ISD::ROTL, VT);
+ bool HasROTR = hasOperation(ISD::ROTR, VT);
if (!HasROTL && !HasROTR) return nullptr;
+ // Check for truncated rotate.
+ if (LHS.getOpcode() == ISD::TRUNCATE && RHS.getOpcode() == ISD::TRUNCATE &&
+ LHS.getOperand(0).getValueType() == RHS.getOperand(0).getValueType()) {
+ assert(LHS.getValueType() == RHS.getValueType());
+ if (SDNode *Rot = MatchRotate(LHS.getOperand(0), RHS.getOperand(0), DL)) {
+ return DAG.getNode(ISD::TRUNCATE, SDLoc(LHS), LHS.getValueType(),
+ SDValue(Rot, 0)).getNode();
+ }
+ }
+
// Match "(X shl/srl V1) & V2" where V2 may not be present.
SDValue LHSShift; // The shift.
SDValue LHSMask; // AND value if any.
- if (!MatchRotateHalf(LHS, LHSShift, LHSMask))
- return nullptr; // Not part of a rotate.
+ matchRotateHalf(DAG, LHS, LHSShift, LHSMask);
SDValue RHSShift; // The shift.
SDValue RHSMask; // AND value if any.
- if (!MatchRotateHalf(RHS, RHSShift, RHSMask))
- return nullptr; // Not part of a rotate.
+ matchRotateHalf(DAG, RHS, RHSShift, RHSMask);
+
+ // If neither side matched a rotate half, bail
+ if (!LHSShift && !RHSShift)
+ return nullptr;
+
+ // InstCombine may have combined a constant shl, srl, mul, or udiv with one
+ // side of the rotate, so try to handle that here. In all cases we need to
+ // pass the matched shift from the opposite side to compute the opcode and
+ // needed shift amount to extract. We still want to do this if both sides
+ // matched a rotate half because one half may be a potential overshift that
+ // can be broken down (i.e. if InstCombine merged two shl or srl ops into a
+ // single one).
+
+ // Have LHS side of the rotate, try to extract the needed shift from the RHS.
+ if (LHSShift)
+ if (SDValue NewRHSShift =
+ extractShiftForRotate(DAG, LHSShift, RHS, RHSMask, DL))
+ RHSShift = NewRHSShift;
+ // Have RHS side of the rotate, try to extract the needed shift from the LHS.
+ if (RHSShift)
+ if (SDValue NewLHSShift =
+ extractShiftForRotate(DAG, RHSShift, LHS, LHSMask, DL))
+ LHSShift = NewLHSShift;
+
+ // If a side is still missing, nothing else we can do.
+ if (!RHSShift || !LHSShift)
+ return nullptr;
+
+ // At this point we've matched or extracted a shift op on each side.
if (LHSShift.getOperand(0) != RHSShift.getOperand(0))
return nullptr; // Not shifting the same value.
@@ -4009,31 +5675,28 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, SDLoc DL) {
// fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1)
// fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2)
- if (isConstOrConstSplat(LHSShiftAmt) && isConstOrConstSplat(RHSShiftAmt)) {
- uint64_t LShVal = isConstOrConstSplat(LHSShiftAmt)->getZExtValue();
- uint64_t RShVal = isConstOrConstSplat(RHSShiftAmt)->getZExtValue();
- if ((LShVal + RShVal) != EltSizeInBits)
- return nullptr;
-
+ auto MatchRotateSum = [EltSizeInBits](ConstantSDNode *LHS,
+ ConstantSDNode *RHS) {
+ return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits;
+ };
+ if (ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
SDValue Rot = DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, VT,
LHSShiftArg, HasROTL ? LHSShiftAmt : RHSShiftAmt);
// If there is an AND of either shifted operand, apply it to the result.
if (LHSMask.getNode() || RHSMask.getNode()) {
- APInt AllBits = APInt::getAllOnesValue(EltSizeInBits);
- SDValue Mask = DAG.getConstant(AllBits, DL, VT);
+ SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
+ SDValue Mask = AllOnes;
if (LHSMask.getNode()) {
- APInt RHSBits = APInt::getLowBitsSet(EltSizeInBits, LShVal);
+ SDValue RHSBits = DAG.getNode(ISD::SRL, DL, VT, AllOnes, RHSShiftAmt);
Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
- DAG.getNode(ISD::OR, DL, VT, LHSMask,
- DAG.getConstant(RHSBits, DL, VT)));
+ DAG.getNode(ISD::OR, DL, VT, LHSMask, RHSBits));
}
if (RHSMask.getNode()) {
- APInt LHSBits = APInt::getHighBitsSet(EltSizeInBits, RShVal);
+ SDValue LHSBits = DAG.getNode(ISD::SHL, DL, VT, AllOnes, LHSShiftAmt);
Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
- DAG.getNode(ISD::OR, DL, VT, RHSMask,
- DAG.getConstant(LHSBits, DL, VT)));
+ DAG.getNode(ISD::OR, DL, VT, RHSMask, LHSBits));
}
Rot = DAG.getNode(ISD::AND, DL, VT, Rot, Mask);
@@ -4075,6 +5738,385 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, SDLoc DL) {
return nullptr;
}
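The patch swaps the APInt mask constants for shifts of an all-ones value; a standalone sketch of the equivalence when the two shift amounts sum to the element size (amounts arbitrary):

  #include <cassert>
  #include <cstdint>
  int main() {
    const unsigned LSh = 8, RSh = 24;            // LSh + RSh == 32
    assert((0xFFFFFFFFu >> RSh) == 0x000000FFu); // == getLowBitsSet(32, LSh)
    assert((0xFFFFFFFFu << LSh) == 0xFFFFFF00u); // == getHighBitsSet(32, RSh)
  }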
+namespace {
+
+/// Represents known origin of an individual byte in load combine pattern. The
+/// value of the byte is either constant zero or comes from memory.
+struct ByteProvider {
+ // For constant zero providers Load is set to nullptr. For memory providers
+ // Load represents the node which loads the byte from memory.
+ // ByteOffset is the offset of the byte in the value produced by the load.
+ LoadSDNode *Load = nullptr;
+ unsigned ByteOffset = 0;
+
+ ByteProvider() = default;
+
+ static ByteProvider getMemory(LoadSDNode *Load, unsigned ByteOffset) {
+ return ByteProvider(Load, ByteOffset);
+ }
+
+ static ByteProvider getConstantZero() { return ByteProvider(nullptr, 0); }
+
+ bool isConstantZero() const { return !Load; }
+ bool isMemory() const { return Load; }
+
+ bool operator==(const ByteProvider &Other) const {
+ return Other.Load == Load && Other.ByteOffset == ByteOffset;
+ }
+
+private:
+ ByteProvider(LoadSDNode *Load, unsigned ByteOffset)
+ : Load(Load), ByteOffset(ByteOffset) {}
+};
+
+} // end anonymous namespace
+
+/// Recursively traverses the expression calculating the origin of the requested
+/// byte of the given value. Returns None if the provider can't be calculated.
+///
+/// For all the values except the root of the expression verifies that the value
+/// has exactly one use and if it's not true return None. This way if the origin
+/// of the byte is returned it's guaranteed that the values which contribute to
+/// the byte are not used outside of this expression.
+///
+/// Because the parts of the expression are not allowed to have more than one
+/// use this function iterates over trees, not DAGs. So it never visits the same
+/// node more than once.
+static const Optional<ByteProvider>
+calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
+ bool Root = false) {
+ // A typical i64-by-i8 pattern requires recursion up to a depth of 8 calls.
+ if (Depth == 10)
+ return None;
+
+ if (!Root && !Op.hasOneUse())
+ return None;
+
+ assert(Op.getValueType().isScalarInteger() && "can't handle other types");
+ unsigned BitWidth = Op.getValueSizeInBits();
+ if (BitWidth % 8 != 0)
+ return None;
+ unsigned ByteWidth = BitWidth / 8;
+ assert(Index < ByteWidth && "invalid index requested");
+ (void) ByteWidth;
+
+ switch (Op.getOpcode()) {
+ case ISD::OR: {
+ auto LHS = calculateByteProvider(Op->getOperand(0), Index, Depth + 1);
+ if (!LHS)
+ return None;
+ auto RHS = calculateByteProvider(Op->getOperand(1), Index, Depth + 1);
+ if (!RHS)
+ return None;
+
+ if (LHS->isConstantZero())
+ return RHS;
+ if (RHS->isConstantZero())
+ return LHS;
+ return None;
+ }
+ case ISD::SHL: {
+ auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
+ if (!ShiftOp)
+ return None;
+
+ uint64_t BitShift = ShiftOp->getZExtValue();
+ if (BitShift % 8 != 0)
+ return None;
+ uint64_t ByteShift = BitShift / 8;
+
+ return Index < ByteShift
+ ? ByteProvider::getConstantZero()
+ : calculateByteProvider(Op->getOperand(0), Index - ByteShift,
+ Depth + 1);
+ }
+ case ISD::ANY_EXTEND:
+ case ISD::SIGN_EXTEND:
+ case ISD::ZERO_EXTEND: {
+ SDValue NarrowOp = Op->getOperand(0);
+ unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
+ if (NarrowBitWidth % 8 != 0)
+ return None;
+ uint64_t NarrowByteWidth = NarrowBitWidth / 8;
+
+ if (Index >= NarrowByteWidth)
+ return Op.getOpcode() == ISD::ZERO_EXTEND
+ ? Optional<ByteProvider>(ByteProvider::getConstantZero())
+ : None;
+ return calculateByteProvider(NarrowOp, Index, Depth + 1);
+ }
+ case ISD::BSWAP:
+ return calculateByteProvider(Op->getOperand(0), ByteWidth - Index - 1,
+ Depth + 1);
+ case ISD::LOAD: {
+ auto L = cast<LoadSDNode>(Op.getNode());
+ if (L->isVolatile() || L->isIndexed())
+ return None;
+
+ unsigned NarrowBitWidth = L->getMemoryVT().getSizeInBits();
+ if (NarrowBitWidth % 8 != 0)
+ return None;
+ uint64_t NarrowByteWidth = NarrowBitWidth / 8;
+
+ if (Index >= NarrowByteWidth)
+ return L->getExtensionType() == ISD::ZEXTLOAD
+ ? Optional<ByteProvider>(ByteProvider::getConstantZero())
+ : None;
+ return ByteProvider::getMemory(L, Index);
+ }
+ }
+
+ return None;
+}
+
+/// Match a pattern where a wide type scalar value is loaded by several narrow
+/// loads and combined by shifts and ors. Fold it into a single load or a load
+/// and a BSWAP if the targets supports it.
+///
+/// Assuming little endian target:
+/// i8 *a = ...
+/// i32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
+/// =>
+/// i32 val = *((i32)a)
+///
+/// i8 *a = ...
+/// i32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
+/// =>
+/// i32 val = BSWAP(*((i32)a))
+///
+/// TODO: This rule matches complex patterns with OR node roots and doesn't
+/// interact well with the worklist mechanism. When a part of the pattern is
+/// updated (e.g. one of the loads) its direct users are put into the worklist,
+/// but the root node of the pattern which triggers the load combine is not
+/// necessarily a direct user of the changed node. For example, once the address
+/// of the t28 load is reassociated, the load combine won't be triggered:
+/// t25: i32 = add t4, Constant:i32<2>
+/// t26: i64 = sign_extend t25
+/// t27: i64 = add t2, t26
+/// t28: i8,ch = load<LD1[%tmp9]> t0, t27, undef:i64
+/// t29: i32 = zero_extend t28
+/// t32: i32 = shl t29, Constant:i8<8>
+/// t33: i32 = or t23, t32
+/// As a possible fix, visitLoad could check whether the load is part of a load
+/// combine pattern and add the corresponding OR roots to the worklist.
+SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
+ assert(N->getOpcode() == ISD::OR &&
+ "Can only match load combining against OR nodes");
+
+ // Handles simple types only
+ EVT VT = N->getValueType(0);
+ if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
+ return SDValue();
+ unsigned ByteWidth = VT.getSizeInBits() / 8;
+
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ // Before legalization we can introduce illegal loads that are too wide; they
+ // will later be split into legal-sized loads. This lets us combine i64-by-i8
+ // patterns into a couple of i32 loads on 32-bit targets.
+ if (LegalOperations && !TLI.isOperationLegal(ISD::LOAD, VT))
+ return SDValue();
+
+ std::function<unsigned(unsigned, unsigned)> LittleEndianByteAt = [](
+ unsigned BW, unsigned i) { return i; };
+ std::function<unsigned(unsigned, unsigned)> BigEndianByteAt = [](
+ unsigned BW, unsigned i) { return BW - i - 1; };
+
+ bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian();
+ auto MemoryByteOffset = [&] (ByteProvider P) {
+ assert(P.isMemory() && "Must be a memory byte provider");
+ unsigned LoadBitWidth = P.Load->getMemoryVT().getSizeInBits();
+ assert(LoadBitWidth % 8 == 0 &&
+ "can only analyze providers for individual bytes not bit");
+ unsigned LoadByteWidth = LoadBitWidth / 8;
+ return IsBigEndianTarget
+ ? BigEndianByteAt(LoadByteWidth, P.ByteOffset)
+ : LittleEndianByteAt(LoadByteWidth, P.ByteOffset);
+ };
+
+ Optional<BaseIndexOffset> Base;
+ SDValue Chain;
+
+ SmallPtrSet<LoadSDNode *, 8> Loads;
+ Optional<ByteProvider> FirstByteProvider;
+ int64_t FirstOffset = INT64_MAX;
+
+ // Check if all the bytes of the OR we are looking at are loaded from the same
+ // base address. Collect byte offsets from the base address in ByteOffsets.
+ SmallVector<int64_t, 4> ByteOffsets(ByteWidth);
+ for (unsigned i = 0; i < ByteWidth; i++) {
+ auto P = calculateByteProvider(SDValue(N, 0), i, 0, /*Root=*/true);
+ if (!P || !P->isMemory()) // All the bytes must be loaded from memory
+ return SDValue();
+
+ LoadSDNode *L = P->Load;
+ assert(L->hasNUsesOfValue(1, 0) && !L->isVolatile() && !L->isIndexed() &&
+ "Must be enforced by calculateByteProvider");
+ assert(L->getOffset().isUndef() && "Unindexed load must have undef offset");
+
+ // All loads must share the same chain
+ SDValue LChain = L->getChain();
+ if (!Chain)
+ Chain = LChain;
+ else if (Chain != LChain)
+ return SDValue();
+
+ // Loads must share the same base address
+ BaseIndexOffset Ptr = BaseIndexOffset::match(L, DAG);
+ int64_t ByteOffsetFromBase = 0;
+ if (!Base)
+ Base = Ptr;
+ else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
+ return SDValue();
+
+ // Calculate the offset of the current byte from the base address
+ ByteOffsetFromBase += MemoryByteOffset(*P);
+ ByteOffsets[i] = ByteOffsetFromBase;
+
+ // Remember the first byte load
+ if (ByteOffsetFromBase < FirstOffset) {
+ FirstByteProvider = P;
+ FirstOffset = ByteOffsetFromBase;
+ }
+
+ Loads.insert(L);
+ }
+ assert(!Loads.empty() && "All the bytes of the value must be loaded from "
+ "memory, so there must be at least one load which produces the value");
+ assert(Base && "Base address of the accessed memory location must be set");
+ assert(FirstOffset != INT64_MAX && "First byte offset must be set");
+
+ // Check if the bytes of the OR we are looking at match with either big or
+ // little endian value load
+ bool BigEndian = true, LittleEndian = true;
+ for (unsigned i = 0; i < ByteWidth; i++) {
+ int64_t CurrentByteOffset = ByteOffsets[i] - FirstOffset;
+ LittleEndian &= CurrentByteOffset == LittleEndianByteAt(ByteWidth, i);
+ BigEndian &= CurrentByteOffset == BigEndianByteAt(ByteWidth, i);
+ if (!BigEndian && !LittleEndian)
+ return SDValue();
+ }
+ assert((BigEndian != LittleEndian) && "should be either big or little endian");
+ assert(FirstByteProvider && "must be set");
+
+ // Ensure that the first byte is loaded from offset zero of the first load,
+ // so that the combined value can be loaded from the first load's address.
+ if (MemoryByteOffset(*FirstByteProvider) != 0)
+ return SDValue();
+ LoadSDNode *FirstLoad = FirstByteProvider->Load;
+
+ // The node we are looking at matches with the pattern, check if we can
+ // replace it with a single load and bswap if needed.
+
+ // If the load needs byte swap check if the target supports it
+ bool NeedsBswap = IsBigEndianTarget != BigEndian;
+
+ // Before legalization we can introduce illegal bswaps, which will later be
+ // converted to an explicit bswap sequence. This way we end up with a single
+ // load and byte shuffling instead of several loads and byte shuffling.
+ if (NeedsBswap && LegalOperations && !TLI.isOperationLegal(ISD::BSWAP, VT))
+ return SDValue();
+
+ // Check that a load of the wide type is both allowed and fast on the target
+ bool Fast = false;
+ bool Allowed = TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(),
+ VT, FirstLoad->getAddressSpace(),
+ FirstLoad->getAlignment(), &Fast);
+ if (!Allowed || !Fast)
+ return SDValue();
+
+ SDValue NewLoad =
+ DAG.getLoad(VT, SDLoc(N), Chain, FirstLoad->getBasePtr(),
+ FirstLoad->getPointerInfo(), FirstLoad->getAlignment());
+
+ // Transfer chain users from old loads to the new load.
+ for (LoadSDNode *L : Loads)
+ DAG.ReplaceAllUsesOfValueWith(SDValue(L, 1), SDValue(NewLoad.getNode(), 1));
+
+ return NeedsBswap ? DAG.getNode(ISD::BSWAP, SDLoc(N), VT, NewLoad) : NewLoad;
+}
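The source-level shape this targets, as a standalone sketch (function name illustrative; little-endian target assumed):

  #include <cstdint>
  uint32_t load_le32(const uint8_t *a) {
    return (uint32_t)a[0] | ((uint32_t)a[1] << 8) |
           ((uint32_t)a[2] << 16) | ((uint32_t)a[3] << 24);
    // Combines to a single i32 load; the reversed byte order would instead
    // become (bswap (load i32)) where BSWAP is legal or custom.
  }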
+
+// If the target has andn, bsl, or a similar bit-select instruction,
+// we want to unfold masked merge, with canonical pattern of:
+//   |        A         |  |B|
+//   ((x ^ y) & m) ^ y
+//    |  D  |
+// Into:
+// (x & m) | (y & ~m)
+// If y is a constant, and the 'andn' does not work with immediates,
+// we unfold into a different pattern:
+// ~(~x & m) & (m | y)
+// NOTE: we don't unfold the pattern if 'xor' is actually a 'not', because at
+// the very least that breaks andnpd / andnps patterns, and because those
+// patterns are simplified in IR and shouldn't be created in the DAG.
+SDValue DAGCombiner::unfoldMaskedMerge(SDNode *N) {
+ assert(N->getOpcode() == ISD::XOR);
+
+ // Don't touch 'not' (i.e. where y = -1).
+ if (isAllOnesOrAllOnesSplat(N->getOperand(1)))
+ return SDValue();
+
+ EVT VT = N->getValueType(0);
+
+ // There are 3 commutable operators in the pattern,
+ // so we have to deal with 8 possible variants of the basic pattern.
+ SDValue X, Y, M;
+ auto matchAndXor = [&X, &Y, &M](SDValue And, unsigned XorIdx, SDValue Other) {
+ if (And.getOpcode() != ISD::AND || !And.hasOneUse())
+ return false;
+ SDValue Xor = And.getOperand(XorIdx);
+ if (Xor.getOpcode() != ISD::XOR || !Xor.hasOneUse())
+ return false;
+ SDValue Xor0 = Xor.getOperand(0);
+ SDValue Xor1 = Xor.getOperand(1);
+ // Don't touch 'not' (i.e. where y = -1).
+ if (isAllOnesOrAllOnesSplat(Xor1))
+ return false;
+ if (Other == Xor0)
+ std::swap(Xor0, Xor1);
+ if (Other != Xor1)
+ return false;
+ X = Xor0;
+ Y = Xor1;
+ M = And.getOperand(XorIdx ? 0 : 1);
+ return true;
+ };
+
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ if (!matchAndXor(N0, 0, N1) && !matchAndXor(N0, 1, N1) &&
+ !matchAndXor(N1, 0, N0) && !matchAndXor(N1, 1, N0))
+ return SDValue();
+
+ // Don't do anything if the mask is constant. This should not be reachable.
+ // InstCombine should have already unfolded this pattern, and DAGCombiner
+ // probably shouldn't produce it, too.
+ if (isa<ConstantSDNode>(M.getNode()))
+ return SDValue();
+
+ // We can transform if the target has AndNot
+ if (!TLI.hasAndNot(M))
+ return SDValue();
+
+ SDLoc DL(N);
+
+ // If Y is a constant, check that 'andn' works with immediates.
+ if (!TLI.hasAndNot(Y)) {
+ assert(TLI.hasAndNot(X) && "Only mask is a variable? Unreachable.");
+ // If not, we need to do a bit more work to make sure andn is still used.
+ SDValue NotX = DAG.getNOT(DL, X, VT);
+ SDValue LHS = DAG.getNode(ISD::AND, DL, VT, NotX, M);
+ SDValue NotLHS = DAG.getNOT(DL, LHS, VT);
+ SDValue RHS = DAG.getNode(ISD::OR, DL, VT, M, Y);
+ return DAG.getNode(ISD::AND, DL, VT, NotLHS, RHS);
+ }
+
+ SDValue LHS = DAG.getNode(ISD::AND, DL, VT, X, M);
+ SDValue NotM = DAG.getNOT(DL, M, VT);
+ SDValue RHS = DAG.getNode(ISD::AND, DL, VT, Y, NotM);
+
+ return DAG.getNode(ISD::OR, DL, VT, LHS, RHS);
+}
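Both unfoldings are plain bit identities; an exhaustive 8-bit standalone check (not from the patch):

  #include <cassert>
  #include <cstdint>
  int main() {
    for (unsigned x = 0; x < 256; ++x)
      for (unsigned y = 0; y < 256; ++y)
        for (unsigned m = 0; m < 256; ++m) {
          uint8_t d = ((x ^ y) & m) ^ y;               // canonical masked merge
          assert(d == (uint8_t)((x & m) | (y & ~m)));  // bit-select form
          assert(d == (uint8_t)(~(~x & m) & (m | y))); // constant-y form
        }
  }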
+
SDValue DAGCombiner::visitXOR(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
@@ -4093,112 +6135,138 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
}
// fold (xor undef, undef) -> 0. This is a common idiom (misuse).
- if (N0.getOpcode() == ISD::UNDEF && N1.getOpcode() == ISD::UNDEF)
- return DAG.getConstant(0, SDLoc(N), VT);
+ SDLoc DL(N);
+ if (N0.isUndef() && N1.isUndef())
+ return DAG.getConstant(0, DL, VT);
// fold (xor x, undef) -> undef
- if (N0.getOpcode() == ISD::UNDEF)
+ if (N0.isUndef())
return N0;
- if (N1.getOpcode() == ISD::UNDEF)
+ if (N1.isUndef())
return N1;
// fold (xor c1, c2) -> c1^c2
ConstantSDNode *N0C = getAsNonOpaqueConstant(N0);
ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);
if (N0C && N1C)
- return DAG.FoldConstantArithmetic(ISD::XOR, SDLoc(N), VT, N0C, N1C);
+ return DAG.FoldConstantArithmetic(ISD::XOR, DL, VT, N0C, N1C);
// canonicalize constant to RHS
- if (isConstantIntBuildVectorOrConstantInt(N0) &&
- !isConstantIntBuildVectorOrConstantInt(N1))
- return DAG.getNode(ISD::XOR, SDLoc(N), VT, N1, N0);
+ if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
+ !DAG.isConstantIntBuildVectorOrConstantInt(N1))
+ return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
// fold (xor x, 0) -> x
if (isNullConstant(N1))
return N0;
+
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
// reassociate xor
- if (SDValue RXOR = ReassociateOps(ISD::XOR, SDLoc(N), N0, N1))
+ if (SDValue RXOR = ReassociateOps(ISD::XOR, DL, N0, N1, N->getFlags()))
return RXOR;
// fold !(x cc y) -> (x !cc y)
+ unsigned N0Opcode = N0.getOpcode();
SDValue LHS, RHS, CC;
if (TLI.isConstTrueVal(N1.getNode()) && isSetCCEquivalent(N0, LHS, RHS, CC)) {
- bool isInt = LHS.getValueType().isInteger();
ISD::CondCode NotCC = ISD::getSetCCInverse(cast<CondCodeSDNode>(CC)->get(),
- isInt);
-
+ LHS.getValueType().isInteger());
if (!LegalOperations ||
TLI.isCondCodeLegal(NotCC, LHS.getSimpleValueType())) {
- switch (N0.getOpcode()) {
+ switch (N0Opcode) {
default:
llvm_unreachable("Unhandled SetCC Equivalent!");
case ISD::SETCC:
- return DAG.getSetCC(SDLoc(N), VT, LHS, RHS, NotCC);
+ return DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC);
case ISD::SELECT_CC:
- return DAG.getSelectCC(SDLoc(N), LHS, RHS, N0.getOperand(2),
+ return DAG.getSelectCC(SDLoc(N0), LHS, RHS, N0.getOperand(2),
N0.getOperand(3), NotCC);
}
}
}
// fold (not (zext (setcc x, y))) -> (zext (not (setcc x, y)))
- if (isOneConstant(N1) && N0.getOpcode() == ISD::ZERO_EXTEND &&
- N0.getNode()->hasOneUse() &&
+ if (isOneConstant(N1) && N0Opcode == ISD::ZERO_EXTEND && N0.hasOneUse() &&
isSetCCEquivalent(N0.getOperand(0), LHS, RHS, CC)){
SDValue V = N0.getOperand(0);
- SDLoc DL(N0);
- V = DAG.getNode(ISD::XOR, DL, V.getValueType(), V,
- DAG.getConstant(1, DL, V.getValueType()));
+ SDLoc DL0(N0);
+ V = DAG.getNode(ISD::XOR, DL0, V.getValueType(), V,
+ DAG.getConstant(1, DL0, V.getValueType()));
AddToWorklist(V.getNode());
- return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, V);
+ return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, V);
}
// fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are setcc
- if (isOneConstant(N1) && VT == MVT::i1 &&
- (N0.getOpcode() == ISD::OR || N0.getOpcode() == ISD::AND)) {
+ if (isOneConstant(N1) && VT == MVT::i1 && N0.hasOneUse() &&
+ (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
SDValue LHS = N0.getOperand(0), RHS = N0.getOperand(1);
if (isOneUseSetCC(RHS) || isOneUseSetCC(LHS)) {
- unsigned NewOpcode = N0.getOpcode() == ISD::AND ? ISD::OR : ISD::AND;
+ unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
LHS = DAG.getNode(ISD::XOR, SDLoc(LHS), VT, LHS, N1); // LHS = ~LHS
RHS = DAG.getNode(ISD::XOR, SDLoc(RHS), VT, RHS, N1); // RHS = ~RHS
AddToWorklist(LHS.getNode()); AddToWorklist(RHS.getNode());
- return DAG.getNode(NewOpcode, SDLoc(N), VT, LHS, RHS);
+ return DAG.getNode(NewOpcode, DL, VT, LHS, RHS);
}
}
// fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are constants
- if (isAllOnesConstant(N1) &&
- (N0.getOpcode() == ISD::OR || N0.getOpcode() == ISD::AND)) {
+ if (isAllOnesConstant(N1) && N0.hasOneUse() &&
+ (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
SDValue LHS = N0.getOperand(0), RHS = N0.getOperand(1);
if (isa<ConstantSDNode>(RHS) || isa<ConstantSDNode>(LHS)) {
- unsigned NewOpcode = N0.getOpcode() == ISD::AND ? ISD::OR : ISD::AND;
+ unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
LHS = DAG.getNode(ISD::XOR, SDLoc(LHS), VT, LHS, N1); // LHS = ~LHS
RHS = DAG.getNode(ISD::XOR, SDLoc(RHS), VT, RHS, N1); // RHS = ~RHS
AddToWorklist(LHS.getNode()); AddToWorklist(RHS.getNode());
- return DAG.getNode(NewOpcode, SDLoc(N), VT, LHS, RHS);
+ return DAG.getNode(NewOpcode, DL, VT, LHS, RHS);
}
}
// fold (xor (and x, y), y) -> (and (not x), y)
- if (N0.getOpcode() == ISD::AND && N0.getNode()->hasOneUse() &&
- N0->getOperand(1) == N1) {
- SDValue X = N0->getOperand(0);
+ if (N0Opcode == ISD::AND && N0.hasOneUse() && N0->getOperand(1) == N1) {
+ SDValue X = N0.getOperand(0);
SDValue NotX = DAG.getNOT(SDLoc(X), X, VT);
AddToWorklist(NotX.getNode());
- return DAG.getNode(ISD::AND, SDLoc(N), VT, NotX, N1);
- }
- // fold (xor (xor x, c1), c2) -> (xor x, (xor c1, c2))
- if (N1C && N0.getOpcode() == ISD::XOR) {
- if (const ConstantSDNode *N00C = getAsNonOpaqueConstant(N0.getOperand(0))) {
- SDLoc DL(N);
- return DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(1),
- DAG.getConstant(N1C->getAPIntValue() ^
- N00C->getAPIntValue(), DL, VT));
+ return DAG.getNode(ISD::AND, DL, VT, NotX, N1);
+ }
+
+ if ((N0Opcode == ISD::SRL || N0Opcode == ISD::SHL) && N0.hasOneUse()) {
+ ConstantSDNode *XorC = isConstOrConstSplat(N1);
+ ConstantSDNode *ShiftC = isConstOrConstSplat(N0.getOperand(1));
+ unsigned BitWidth = VT.getScalarSizeInBits();
+ if (XorC && ShiftC) {
+ // Don't crash on an oversized shift. We cannot guarantee that a bogus
+ // shift has been simplified to undef.
+ uint64_t ShiftAmt = ShiftC->getLimitedValue();
+ if (ShiftAmt < BitWidth) {
+ APInt Ones = APInt::getAllOnesValue(BitWidth);
+ Ones = N0Opcode == ISD::SHL ? Ones.shl(ShiftAmt) : Ones.lshr(ShiftAmt);
+ if (XorC->getAPIntValue() == Ones) {
+ // If the xor constant is a shifted -1, do a 'not' before the shift:
+ // xor (X << ShiftC), XorC --> (not X) << ShiftC
+ // xor (X >> ShiftC), XorC --> (not X) >> ShiftC
+ SDValue Not = DAG.getNOT(DL, N0.getOperand(0), VT);
+ return DAG.getNode(N0Opcode, DL, VT, Not, N0.getOperand(1));
+ }
+ }
}
- if (const ConstantSDNode *N01C = getAsNonOpaqueConstant(N0.getOperand(1))) {
- SDLoc DL(N);
- return DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(0),
- DAG.getConstant(N1C->getAPIntValue() ^
- N01C->getAPIntValue(), DL, VT));
+ }
+
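A standalone sketch of the shifted-not identity for both shift directions (shift amount arbitrary, below the bit width):

  #include <cassert>
  #include <cstdint>
  int main() {
    uint32_t x = 0x12345678u;
    const unsigned c = 5;                                 // ShiftAmt < BitWidth
    assert(((x << c) ^ (0xFFFFFFFFu << c)) == (~x << c)); // shl case
    assert(((x >> c) ^ (0xFFFFFFFFu >> c)) == (~x >> c)); // srl case
  }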
+ // fold Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X)
+ if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) {
+ SDValue A = N0Opcode == ISD::ADD ? N0 : N1;
+ SDValue S = N0Opcode == ISD::SRA ? N0 : N1;
+ if (A.getOpcode() == ISD::ADD && S.getOpcode() == ISD::SRA) {
+ SDValue A0 = A.getOperand(0), A1 = A.getOperand(1);
+ SDValue S0 = S.getOperand(0);
+ if ((A0 == S && A1 == S0) || (A1 == S && A0 == S0)) {
+ unsigned OpSizeInBits = VT.getScalarSizeInBits();
+ if (ConstantSDNode *C = isConstOrConstSplat(S.getOperand(1)))
+ if (C->getAPIntValue() == (OpSizeInBits - 1))
+ return DAG.getNode(ISD::ABS, DL, VT, S0);
+ }
}
}
+
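The classic branch-free abs idiom being recognized, as a standalone sketch (assumes arithmetic right shift of signed values, as on the usual targets; x == INT32_MIN is excluded):

  #include <cassert>
  #include <cstdint>
  int main() {
    for (int32_t x : {5, -5, 0, 123456, -123456}) {
      int32_t y = x >> 31;                       // Y = sra (X, size(X)-1): 0 or -1
      assert(((x + y) ^ y) == (x < 0 ? -x : x)); // xor (add (X, Y), Y) == abs X
    }
  }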
// fold (xor x, x) -> 0
if (N0 == N1)
- return tryFoldToZero(SDLoc(N), TLI, VT, DAG, LegalOperations, LegalTypes);
+ return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
// fold (xor (shl 1, x), -1) -> (rotl ~1, x)
// Here is a concrete example of this equivalence:
@@ -4218,21 +6286,23 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
// consistent result.
// - Pushing the zero left requires shifting one bits in from the right.
// A rotate left of ~1 is a nice way of achieving the desired result.
- if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT) && N0.getOpcode() == ISD::SHL
- && isAllOnesConstant(N1) && isOneConstant(N0.getOperand(0))) {
- SDLoc DL(N);
+ if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT) && N0Opcode == ISD::SHL &&
+ isAllOnesConstant(N1) && isOneConstant(N0.getOperand(0))) {
return DAG.getNode(ISD::ROTL, DL, VT, DAG.getConstant(~1, DL, VT),
N0.getOperand(1));
}
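A small standalone check (illustrative only) of the equivalence described in the comment above: the single zero bit of ~1, rotated left by s, lands exactly where (xor (shl 1, s), -1) clears its bit.

#include <cassert>
#include <cstdint>

uint32_t rotl32(uint32_t x, unsigned c) {
  c &= 31;
  return c ? (x << c) | (x >> (32 - c)) : x;
}

int main() {
  for (unsigned s = 0; s < 32; ++s)
    assert(((1u << s) ^ 0xffffffffu) == rotl32(~1u, s));
}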
// Simplify: xor (op x...), (op y...) -> (op (xor x, y))
- if (N0.getOpcode() == N1.getOpcode())
- if (SDValue Tmp = SimplifyBinOpWithSameOpcodeHands(N))
- return Tmp;
+ if (N0Opcode == N1.getOpcode())
+ if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
+ return V;
+
+ // Unfold ((x ^ y) & m) ^ y into (x & m) | (y & ~m) if profitable
+ if (SDValue MM = unfoldMaskedMerge(N))
+ return MM;
// Simplify the expression using non-local knowledge.
- if (!VT.isVector() &&
- SimplifyDemandedBits(SDValue(N, 0)))
+ if (SimplifyDemandedBits(SDValue(N, 0)))
return SDValue(N, 0);
return SDValue();
@@ -4241,6 +6311,10 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
/// Handle transforms common to the three shifts, when the shift amount is a
/// constant.
SDValue DAGCombiner::visitShiftByConstant(SDNode *N, ConstantSDNode *Amt) {
+ // Do not turn a 'not' into a regular xor.
+ if (isBitwiseNot(N->getOperand(0)))
+ return SDValue();
+
SDNode *LHS = N->getOperand(0).getNode();
if (!LHS->hasOneUse()) return SDValue();
@@ -4270,16 +6344,20 @@ SDValue DAGCombiner::visitShiftByConstant(SDNode *N, ConstantSDNode *Amt) {
ConstantSDNode *BinOpCst = getAsNonOpaqueConstant(LHS->getOperand(1));
if (!BinOpCst) return SDValue();
- // FIXME: disable this unless the input to the binop is a shift by a constant.
- // If it is not a shift, it pessimizes some common cases like:
- //
- // void foo(int *X, int i) { X[i & 1235] = 1; }
- // int bar(int *X, int i) { return X[i & 255]; }
+ // FIXME: disable this unless the input to the binop is a shift by a constant
+ // or is copy/select. Enable this in other cases once it is known to be
+ // exactly profitable.
SDNode *BinOpLHSVal = LHS->getOperand(0).getNode();
- if ((BinOpLHSVal->getOpcode() != ISD::SHL &&
- BinOpLHSVal->getOpcode() != ISD::SRA &&
- BinOpLHSVal->getOpcode() != ISD::SRL) ||
- !isa<ConstantSDNode>(BinOpLHSVal->getOperand(1)))
+ bool isShift = BinOpLHSVal->getOpcode() == ISD::SHL ||
+ BinOpLHSVal->getOpcode() == ISD::SRA ||
+ BinOpLHSVal->getOpcode() == ISD::SRL;
+ bool isCopyOrSelect = BinOpLHSVal->getOpcode() == ISD::CopyFromReg ||
+ BinOpLHSVal->getOpcode() == ISD::SELECT;
+
+ if ((!isShift || !isa<ConstantSDNode>(BinOpLHSVal->getOperand(1))) &&
+ !isCopyOrSelect)
+ return SDValue();
+
+ if (isCopyOrSelect && N->hasOneUse())
return SDValue();
EVT VT = N->getValueType(0);
@@ -4294,7 +6372,7 @@ SDValue DAGCombiner::visitShiftByConstant(SDNode *N, ConstantSDNode *Amt) {
return SDValue();
}
- if (!TLI.isDesirableToCommuteWithShift(LHS))
+ if (!TLI.isDesirableToCommuteWithShift(N, Level))
return SDValue();
// Fold the constants, shifting the binop RHS by the shift amount.
@@ -4319,19 +6397,15 @@ SDValue DAGCombiner::distributeTruncateThroughAnd(SDNode *N) {
// (truncate:TruncVT (and N00, N01C)) -> (and (truncate:TruncVT N00), TruncC)
if (N->hasOneUse() && N->getOperand(0).hasOneUse()) {
SDValue N01 = N->getOperand(0).getOperand(1);
-
- if (ConstantSDNode *N01C = isConstOrConstSplat(N01)) {
- if (!N01C->isOpaque()) {
- EVT TruncVT = N->getValueType(0);
- SDValue N00 = N->getOperand(0).getOperand(0);
- APInt TruncC = N01C->getAPIntValue();
- TruncC = TruncC.trunc(TruncVT.getScalarSizeInBits());
- SDLoc DL(N);
-
- return DAG.getNode(ISD::AND, DL, TruncVT,
- DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N00),
- DAG.getConstant(TruncC, DL, TruncVT));
- }
+ if (isConstantOrConstantVector(N01, /* NoOpaques */ true)) {
+ SDLoc DL(N);
+ EVT TruncVT = N->getValueType(0);
+ SDValue N00 = N->getOperand(0).getOperand(0);
+ SDValue Trunc00 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N00);
+ SDValue Trunc01 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N01);
+ AddToWorklist(Trunc00.getNode());
+ AddToWorklist(Trunc01.getNode());
+ return DAG.getNode(ISD::AND, DL, TruncVT, Trunc00, Trunc01);
}
}
@@ -4339,13 +6413,58 @@ SDValue DAGCombiner::distributeTruncateThroughAnd(SDNode *N) {
}
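The underlying identity is simply that truncation distributes over bitwise AND; a one-line C++ sanity check (illustrative only):

#include <cassert>
#include <cstdint>

int main() {
  uint32_t y = 0x12345678u, c = 0x0000ff0fu;
  // (truncate:i16 (and y, c)) == (and (truncate:i16 y), (truncate:i16 c))
  assert((uint16_t)(y & c) == (uint16_t)((uint16_t)y & (uint16_t)c));
}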
SDValue DAGCombiner::visitRotate(SDNode *N) {
+ SDLoc dl(N);
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ EVT VT = N->getValueType(0);
+ unsigned Bitsize = VT.getScalarSizeInBits();
+
+ // fold (rot x, 0) -> x
+ if (isNullOrNullSplat(N1))
+ return N0;
+
+ // fold (rot x, c) -> x iff (c % BitSize) == 0
+ if (isPowerOf2_32(Bitsize) && Bitsize > 1) {
+ APInt ModuloMask(N1.getScalarValueSizeInBits(), Bitsize - 1);
+ if (DAG.MaskedValueIsZero(N1, ModuloMask))
+ return N0;
+ }
+
+ // fold (rot x, c) -> (rot x, c % BitSize)
+ if (ConstantSDNode *Cst = isConstOrConstSplat(N1)) {
+ if (Cst->getAPIntValue().uge(Bitsize)) {
+ uint64_t RotAmt = Cst->getAPIntValue().urem(Bitsize);
+ return DAG.getNode(N->getOpcode(), dl, VT, N0,
+ DAG.getConstant(RotAmt, dl, N1.getValueType()));
+ }
+ }
+
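For reference, a small C++ model of the two rotate-amount folds above (illustrative only): rotates are invariant under amounts that are multiples of the bit width, and any constant amount can be reduced modulo the bit width. The rotate is implemented as repeated 1-bit rotates so it handles any count naturally.

#include <cassert>
#include <cstdint>

uint32_t rotl32(uint32_t x, unsigned c) {
  for (unsigned i = 0; i < c; ++i)
    x = (x << 1) | (x >> 31);
  return x;
}

int main() {
  uint32_t x = 0x80000001u;
  assert(rotl32(x, 32) == x);            // fold (rot x, c) -> x, c % 32 == 0
  assert(rotl32(x, 37) == rotl32(x, 5)); // fold (rot x, c) -> (rot x, c % 32)
}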
// fold (rot* x, (trunc (and y, c))) -> (rot* x, (and (trunc y), (trunc c))).
- if (N->getOperand(1).getOpcode() == ISD::TRUNCATE &&
- N->getOperand(1).getOperand(0).getOpcode() == ISD::AND) {
- SDValue NewOp1 = distributeTruncateThroughAnd(N->getOperand(1).getNode());
- if (NewOp1.getNode())
- return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0),
- N->getOperand(0), NewOp1);
+ if (N1.getOpcode() == ISD::TRUNCATE &&
+ N1.getOperand(0).getOpcode() == ISD::AND) {
+ if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
+ return DAG.getNode(N->getOpcode(), dl, VT, N0, NewOp1);
+ }
+
+ unsigned NextOp = N0.getOpcode();
+ // fold (rot* (rot* x, c2), c1) -> (rot* x, c1 +- c2 % bitsize)
+ if (NextOp == ISD::ROTL || NextOp == ISD::ROTR) {
+ SDNode *C1 = DAG.isConstantIntBuildVectorOrConstantInt(N1);
+ SDNode *C2 = DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1));
+ if (C1 && C2 && C1->getValueType(0) == C2->getValueType(0)) {
+ EVT ShiftVT = C1->getValueType(0);
+ bool SameSide = (N->getOpcode() == NextOp);
+ unsigned CombineOp = SameSide ? ISD::ADD : ISD::SUB;
+ if (SDValue CombinedShift =
+ DAG.FoldConstantArithmetic(CombineOp, dl, ShiftVT, C1, C2)) {
+ SDValue BitsizeC = DAG.getConstant(Bitsize, dl, ShiftVT);
+ SDValue CombinedShiftNorm = DAG.FoldConstantArithmetic(
+ ISD::SREM, dl, ShiftVT, CombinedShift.getNode(),
+ BitsizeC.getNode());
+ return DAG.getNode(N->getOpcode(), dl, VT, N0->getOperand(0),
+ CombinedShiftNorm);
+ }
+ }
}
return SDValue();
}
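A standalone sketch (illustrative only) of the rotate-merging fold above: same-direction rotate amounts add, opposite-direction amounts subtract, in both cases modulo the bit width.

#include <cassert>
#include <cstdint>

uint32_t rotl32(uint32_t x, unsigned c) {
  for (unsigned i = 0; i < c; ++i)
    x = (x << 1) | (x >> 31);
  return x;
}
uint32_t rotr32(uint32_t x, unsigned c) {
  for (unsigned i = 0; i < c; ++i)
    x = (x >> 1) | (x << 31);
  return x;
}

int main() {
  uint32_t x = 0xdeadbeefu;
  // Same direction: amounts add. rotl (rotl x, 7), 9 --> rotl x, 16
  assert(rotl32(rotl32(x, 7), 9) == rotl32(x, 16));
  // Opposite directions: amounts subtract. rotl (rotr x, 3), 7 --> rotl x, 4
  assert(rotl32(rotr32(x, 3), 7) == rotl32(x, 4));
}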
@@ -4353,11 +6472,13 @@ SDValue DAGCombiner::visitRotate(SDNode *N) {
SDValue DAGCombiner::visitSHL(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
+ if (SDValue V = DAG.simplifyShift(N0, N1))
+ return V;
+
EVT VT = N0.getValueType();
unsigned OpSizeInBits = VT.getScalarSizeInBits();
// fold vector ops
- ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
if (VT.isVector()) {
if (SDValue FoldedVOp = SimplifyVBinOp(N))
return FoldedVOp;
@@ -4378,28 +6499,20 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
N01CV, N1CV))
return DAG.getNode(ISD::AND, SDLoc(N), VT, N00, C);
}
- } else {
- N1C = isConstOrConstSplat(N1);
}
}
}
+ ConstantSDNode *N1C = isConstOrConstSplat(N1);
+
// fold (shl c1, c2) -> c1<<c2
ConstantSDNode *N0C = getAsNonOpaqueConstant(N0);
if (N0C && N1C && !N1C->isOpaque())
return DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT, N0C, N1C);
- // fold (shl 0, x) -> 0
- if (isNullConstant(N0))
- return N0;
- // fold (shl x, c >= size(x)) -> undef
- if (N1C && N1C->getAPIntValue().uge(OpSizeInBits))
- return DAG.getUNDEF(VT);
- // fold (shl x, 0) -> x
- if (N1C && N1C->isNullValue())
- return N0;
- // fold (shl undef, x) -> 0
- if (N0.getOpcode() == ISD::UNDEF)
- return DAG.getConstant(0, SDLoc(N), VT);
+
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
// if (shl x, c) is known to be zero, return 0
if (DAG.MaskedValueIsZero(SDValue(N, 0),
APInt::getAllOnesValue(OpSizeInBits)))
@@ -4407,8 +6520,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
// fold (shl x, (trunc (and y, c))) -> (shl x, (and (trunc y), (trunc c))).
if (N1.getOpcode() == ISD::TRUNCATE &&
N1.getOperand(0).getOpcode() == ISD::AND) {
- SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode());
- if (NewOp1.getNode())
+ if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, NewOp1);
}
@@ -4416,15 +6528,29 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
return SDValue(N, 0);
// fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2))
- if (N1C && N0.getOpcode() == ISD::SHL) {
- if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
- uint64_t c1 = N0C1->getZExtValue();
- uint64_t c2 = N1C->getZExtValue();
+ if (N0.getOpcode() == ISD::SHL) {
+ auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
+ ConstantSDNode *RHS) {
+ APInt c1 = LHS->getAPIntValue();
+ APInt c2 = RHS->getAPIntValue();
+ zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
+ return (c1 + c2).uge(OpSizeInBits);
+ };
+ if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
+ return DAG.getConstant(0, SDLoc(N), VT);
+
+ auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
+ ConstantSDNode *RHS) {
+ APInt c1 = LHS->getAPIntValue();
+ APInt c2 = RHS->getAPIntValue();
+ zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
+ return (c1 + c2).ult(OpSizeInBits);
+ };
+ if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
SDLoc DL(N);
- if (c1 + c2 >= OpSizeInBits)
- return DAG.getConstant(0, DL, VT);
- return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0),
- DAG.getConstant(c1 + c2, DL, N1.getValueType()));
+ EVT ShiftVT = N1.getValueType();
+ SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
+ return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Sum);
}
}
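The two predicates above split the constant cases of (shl (shl x, c1), c2): if c1 + c2 reaches the bit width, every bit is shifted out; otherwise the shifts merge. A small C++ check on i8 (illustrative only; the widened APInt addition in the code exists to avoid overflow when summing the amounts, which the native arithmetic below doesn't exercise):

#include <cassert>
#include <cstdint>

int main() {
  uint8_t x = 0x81;
  // In range: (shl (shl x, 2), 3) --> (shl x, 5)
  assert((uint8_t)((uint8_t)(x << 2) << 3) == (uint8_t)(x << 5));
  // Out of range: (shl (shl x, 4), 4) --> 0, since 4 + 4 >= 8
  assert((uint8_t)((uint8_t)(x << 4) << 4) == 0);
}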
@@ -4439,18 +6565,22 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
N0.getOperand(0).getOpcode() == ISD::SHL) {
SDValue N0Op0 = N0.getOperand(0);
if (ConstantSDNode *N0Op0C1 = isConstOrConstSplat(N0Op0.getOperand(1))) {
- uint64_t c1 = N0Op0C1->getZExtValue();
- uint64_t c2 = N1C->getZExtValue();
+ APInt c1 = N0Op0C1->getAPIntValue();
+ APInt c2 = N1C->getAPIntValue();
+ zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
+
EVT InnerShiftVT = N0Op0.getValueType();
uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits();
- if (c2 >= OpSizeInBits - InnerShiftSize) {
+ if (c2.uge(OpSizeInBits - InnerShiftSize)) {
SDLoc DL(N0);
- if (c1 + c2 >= OpSizeInBits)
+ APInt Sum = c1 + c2;
+ if (Sum.uge(OpSizeInBits))
return DAG.getConstant(0, DL, VT);
- return DAG.getNode(ISD::SHL, DL, VT,
- DAG.getNode(N0.getOpcode(), DL, VT,
- N0Op0->getOperand(0)),
- DAG.getConstant(c1 + c2, DL, N1.getValueType()));
+
+ return DAG.getNode(
+ ISD::SHL, DL, VT,
+ DAG.getNode(N0.getOpcode(), DL, VT, N0Op0->getOperand(0)),
+ DAG.getConstant(Sum.getZExtValue(), DL, N1.getValueType()));
}
}
}
@@ -4462,8 +6592,8 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
N0.getOperand(0).getOpcode() == ISD::SRL) {
SDValue N0Op0 = N0.getOperand(0);
if (ConstantSDNode *N0Op0C1 = isConstOrConstSplat(N0Op0.getOperand(1))) {
- uint64_t c1 = N0Op0C1->getZExtValue();
- if (c1 < VT.getScalarSizeInBits()) {
+ if (N0Op0C1->getAPIntValue().ult(VT.getScalarSizeInBits())) {
+ uint64_t c1 = N0Op0C1->getZExtValue();
uint64_t c2 = N1C->getZExtValue();
if (c1 == c2) {
SDValue NewOp0 = N0.getOperand(0);
@@ -4482,7 +6612,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
// fold (shl (sr[la] exact X, C1), C2) -> (shl X, (C2-C1)) if C1 <= C2
// fold (shl (sr[la] exact X, C1), C2) -> (sr[la] X, (C2-C1)) if C1 > C2
if (N1C && (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SRA) &&
- cast<BinaryWithFlagsSDNode>(N0)->Flags.hasExact()) {
+ N0->getFlags().hasExact()) {
if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
uint64_t C1 = N0C1->getZExtValue();
uint64_t C2 = N1C->getZExtValue();
@@ -4499,7 +6629,8 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
// (and (srl x, (sub c1, c2), MASK)
// Only fold this if the inner shift has no other uses -- if it does, folding
// this will increase the total number of instructions.
- if (N1C && N0.getOpcode() == ISD::SRL && N0.hasOneUse()) {
+ if (N1C && N0.getOpcode() == ISD::SRL && N0.hasOneUse() &&
+ TLI.shouldFoldShiftPairToMask(N, Level)) {
if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
uint64_t c1 = N0C1->getZExtValue();
if (c1 < OpSizeInBits) {
@@ -4507,12 +6638,12 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
APInt Mask = APInt::getHighBitsSet(OpSizeInBits, OpSizeInBits - c1);
SDValue Shift;
if (c2 > c1) {
- Mask = Mask.shl(c2 - c1);
+ Mask <<= c2 - c1;
SDLoc DL(N);
Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0),
DAG.getConstant(c2 - c1, DL, N1.getValueType()));
} else {
- Mask = Mask.lshr(c1 - c2);
+ Mask.lshrInPlace(c1 - c2);
SDLoc DL(N);
Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0),
DAG.getConstant(c1 - c2, DL, N1.getValueType()));
@@ -4523,37 +6654,39 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
}
}
}
+
// fold (shl (sra x, c1), c1) -> (and x, (shl -1, c1))
- if (N1C && N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(1)) {
- unsigned BitSize = VT.getScalarSizeInBits();
+ if (N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(1) &&
+ isConstantOrConstantVector(N1, /* No Opaques */ true)) {
SDLoc DL(N);
- SDValue HiBitsMask =
- DAG.getConstant(APInt::getHighBitsSet(BitSize,
- BitSize - N1C->getZExtValue()),
- DL, VT);
- return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0),
- HiBitsMask);
+ SDValue AllBits = DAG.getAllOnesConstant(DL, VT);
+ SDValue HiBitsMask = DAG.getNode(ISD::SHL, DL, VT, AllBits, N1);
+ return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), HiBitsMask);
}
// fold (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
+ // fold (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
// Variant of version done on multiply, except mul by a power of 2 is turned
// into a shift.
- APInt Val;
- if (N1C && N0.getOpcode() == ISD::ADD && N0.getNode()->hasOneUse() &&
- (isa<ConstantSDNode>(N0.getOperand(1)) ||
- isConstantSplatVector(N0.getOperand(1).getNode(), Val))) {
+ if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR) &&
+ N0.getNode()->hasOneUse() &&
+ isConstantOrConstantVector(N1, /* No Opaques */ true) &&
+ isConstantOrConstantVector(N0.getOperand(1), /* No Opaques */ true) &&
+ TLI.isDesirableToCommuteWithShift(N, Level)) {
SDValue Shl0 = DAG.getNode(ISD::SHL, SDLoc(N0), VT, N0.getOperand(0), N1);
SDValue Shl1 = DAG.getNode(ISD::SHL, SDLoc(N1), VT, N0.getOperand(1), N1);
- return DAG.getNode(ISD::ADD, SDLoc(N), VT, Shl0, Shl1);
+ AddToWorklist(Shl0.getNode());
+ AddToWorklist(Shl1.getNode());
+ return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, Shl0, Shl1);
}
// fold (shl (mul x, c1), c2) -> (mul x, c1 << c2)
- if (N1C && N0.getOpcode() == ISD::MUL && N0.getNode()->hasOneUse()) {
- if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
- if (SDValue Folded =
- DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N1), VT, N0C1, N1C))
- return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), Folded);
- }
+ if (N0.getOpcode() == ISD::MUL && N0.getNode()->hasOneUse() &&
+ isConstantOrConstantVector(N1, /* No Opaques */ true) &&
+ isConstantOrConstantVector(N0.getOperand(1), /* No Opaques */ true)) {
+ SDValue Shl = DAG.getNode(ISD::SHL, SDLoc(N1), VT, N0.getOperand(1), N1);
+ if (isConstantOrConstantVector(Shl))
+ return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), Shl);
}
if (N1C && !N1C->isOpaque())
@@ -4566,34 +6699,33 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
SDValue DAGCombiner::visitSRA(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
+ if (SDValue V = DAG.simplifyShift(N0, N1))
+ return V;
+
EVT VT = N0.getValueType();
- unsigned OpSizeInBits = VT.getScalarType().getSizeInBits();
+ unsigned OpSizeInBits = VT.getScalarSizeInBits();
+
+ // Arithmetic shifting an all-sign-bit value is a no-op.
+ // fold (sra 0, x) -> 0
+ // fold (sra -1, x) -> -1
+ if (DAG.ComputeNumSignBits(N0) == OpSizeInBits)
+ return N0;
// fold vector ops
- ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
- if (VT.isVector()) {
+ if (VT.isVector())
if (SDValue FoldedVOp = SimplifyVBinOp(N))
return FoldedVOp;
- N1C = isConstOrConstSplat(N1);
- }
+ ConstantSDNode *N1C = isConstOrConstSplat(N1);
// fold (sra c1, c2) -> c1 >>s c2
ConstantSDNode *N0C = getAsNonOpaqueConstant(N0);
if (N0C && N1C && !N1C->isOpaque())
return DAG.FoldConstantArithmetic(ISD::SRA, SDLoc(N), VT, N0C, N1C);
- // fold (sra 0, x) -> 0
- if (isNullConstant(N0))
- return N0;
- // fold (sra -1, x) -> -1
- if (isAllOnesConstant(N0))
- return N0;
- // fold (sra x, (setge c, size(x))) -> undef
- if (N1C && N1C->getZExtValue() >= OpSizeInBits)
- return DAG.getUNDEF(VT);
- // fold (sra x, 0) -> x
- if (N1C && N1C->isNullValue())
- return N0;
+
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
// fold (sra (shl x, c1), c1) -> sext_inreg for some c1 and target supports
// sext_inreg.
if (N1C && N0.getOpcode() == ISD::SHL && N1 == N0.getOperand(1)) {
@@ -4609,14 +6741,30 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
}
// fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2))
- if (N1C && N0.getOpcode() == ISD::SRA) {
- if (ConstantSDNode *C1 = isConstOrConstSplat(N0.getOperand(1))) {
- unsigned Sum = N1C->getZExtValue() + C1->getZExtValue();
- if (Sum >= OpSizeInBits)
- Sum = OpSizeInBits - 1;
- SDLoc DL(N);
- return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0),
- DAG.getConstant(Sum, DL, N1.getValueType()));
+ // clamp (add c1, c2) to max shift.
+ if (N0.getOpcode() == ISD::SRA) {
+ SDLoc DL(N);
+ EVT ShiftVT = N1.getValueType();
+ EVT ShiftSVT = ShiftVT.getScalarType();
+ SmallVector<SDValue, 16> ShiftValues;
+
+ auto SumOfShifts = [&](ConstantSDNode *LHS, ConstantSDNode *RHS) {
+ APInt c1 = LHS->getAPIntValue();
+ APInt c2 = RHS->getAPIntValue();
+ zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
+ APInt Sum = c1 + c2;
+ unsigned ShiftSum =
+ Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue();
+ ShiftValues.push_back(DAG.getConstant(ShiftSum, DL, ShiftSVT));
+ return true;
+ };
+ if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), SumOfShifts)) {
+ SDValue ShiftValue;
+ if (VT.isVector())
+ ShiftValue = DAG.getBuildVector(ShiftVT, DL, ShiftValues);
+ else
+ ShiftValue = ShiftValues[0];
+ return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), ShiftValue);
}
}
@@ -4637,7 +6785,7 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorNumElements());
// Determine the residual right-shift amount.
- signed ShiftAmt = N1C->getZExtValue() - N01C->getZExtValue();
+ int ShiftAmt = N1C->getZExtValue() - N01C->getZExtValue();
// If the shift is not a no-op (in which case this should be just a sign
// extend already), the truncated to type is legal, sign_extend is legal
@@ -4647,7 +6795,6 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
TLI.isOperationLegalOrCustom(ISD::SIGN_EXTEND, TruncVT) &&
TLI.isOperationLegalOrCustom(ISD::TRUNCATE, VT) &&
TLI.isTruncateFree(VT, TruncVT)) {
-
SDLoc DL(N);
SDValue Amt = DAG.getConstant(ShiftAmt, DL,
getShiftAmountTy(N0.getOperand(0).getValueType()));
@@ -4664,8 +6811,7 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
// fold (sra x, (trunc (and y, c))) -> (sra x, (and (trunc y), (trunc c))).
if (N1.getOpcode() == ISD::TRUNCATE &&
N1.getOperand(0).getOpcode() == ISD::AND) {
- SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode());
- if (NewOp1.getNode())
+ if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0, NewOp1);
}
@@ -4698,7 +6844,6 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
if (N1C && SimplifyDemandedBits(SDValue(N, 0)))
return SDValue(N, 0);
-
// If the sign bit is known to be zero, switch this to a SRL.
if (DAG.SignBitIsZero(N0))
return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, N1);
@@ -4713,81 +6858,90 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
SDValue DAGCombiner::visitSRL(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
+ if (SDValue V = DAG.simplifyShift(N0, N1))
+ return V;
+
EVT VT = N0.getValueType();
- unsigned OpSizeInBits = VT.getScalarType().getSizeInBits();
+ unsigned OpSizeInBits = VT.getScalarSizeInBits();
// fold vector ops
- ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
- if (VT.isVector()) {
+ if (VT.isVector())
if (SDValue FoldedVOp = SimplifyVBinOp(N))
return FoldedVOp;
- N1C = isConstOrConstSplat(N1);
- }
+ ConstantSDNode *N1C = isConstOrConstSplat(N1);
// fold (srl c1, c2) -> c1 >>u c2
ConstantSDNode *N0C = getAsNonOpaqueConstant(N0);
if (N0C && N1C && !N1C->isOpaque())
return DAG.FoldConstantArithmetic(ISD::SRL, SDLoc(N), VT, N0C, N1C);
- // fold (srl 0, x) -> 0
- if (isNullConstant(N0))
- return N0;
- // fold (srl x, c >= size(x)) -> undef
- if (N1C && N1C->getZExtValue() >= OpSizeInBits)
- return DAG.getUNDEF(VT);
- // fold (srl x, 0) -> x
- if (N1C && N1C->isNullValue())
- return N0;
+
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
// if (srl x, c) is known to be zero, return 0
if (N1C && DAG.MaskedValueIsZero(SDValue(N, 0),
APInt::getAllOnesValue(OpSizeInBits)))
return DAG.getConstant(0, SDLoc(N), VT);
// fold (srl (srl x, c1), c2) -> 0 or (srl x, (add c1, c2))
- if (N1C && N0.getOpcode() == ISD::SRL) {
- if (ConstantSDNode *N01C = isConstOrConstSplat(N0.getOperand(1))) {
- uint64_t c1 = N01C->getZExtValue();
- uint64_t c2 = N1C->getZExtValue();
+ if (N0.getOpcode() == ISD::SRL) {
+ auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
+ ConstantSDNode *RHS) {
+ APInt c1 = LHS->getAPIntValue();
+ APInt c2 = RHS->getAPIntValue();
+ zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
+ return (c1 + c2).uge(OpSizeInBits);
+ };
+ if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
+ return DAG.getConstant(0, SDLoc(N), VT);
+
+ auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
+ ConstantSDNode *RHS) {
+ APInt c1 = LHS->getAPIntValue();
+ APInt c2 = RHS->getAPIntValue();
+ zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
+ return (c1 + c2).ult(OpSizeInBits);
+ };
+ if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
SDLoc DL(N);
- if (c1 + c2 >= OpSizeInBits)
- return DAG.getConstant(0, DL, VT);
- return DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0),
- DAG.getConstant(c1 + c2, DL, N1.getValueType()));
+ EVT ShiftVT = N1.getValueType();
+ SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
+ return DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Sum);
}
}
// fold (srl (trunc (srl x, c1)), c2) -> 0 or (trunc (srl x, (add c1, c2)))
if (N1C && N0.getOpcode() == ISD::TRUNCATE &&
- N0.getOperand(0).getOpcode() == ISD::SRL &&
- isa<ConstantSDNode>(N0.getOperand(0)->getOperand(1))) {
- uint64_t c1 =
- cast<ConstantSDNode>(N0.getOperand(0)->getOperand(1))->getZExtValue();
- uint64_t c2 = N1C->getZExtValue();
- EVT InnerShiftVT = N0.getOperand(0).getValueType();
- EVT ShiftCountVT = N0.getOperand(0)->getOperand(1).getValueType();
- uint64_t InnerShiftSize = InnerShiftVT.getScalarType().getSizeInBits();
- // This is only valid if the OpSizeInBits + c1 = size of inner shift.
- if (c1 + OpSizeInBits == InnerShiftSize) {
- SDLoc DL(N0);
- if (c1 + c2 >= InnerShiftSize)
- return DAG.getConstant(0, DL, VT);
- return DAG.getNode(ISD::TRUNCATE, DL, VT,
- DAG.getNode(ISD::SRL, DL, InnerShiftVT,
- N0.getOperand(0)->getOperand(0),
- DAG.getConstant(c1 + c2, DL,
- ShiftCountVT)));
+ N0.getOperand(0).getOpcode() == ISD::SRL) {
+ if (auto N001C = isConstOrConstSplat(N0.getOperand(0).getOperand(1))) {
+ uint64_t c1 = N001C->getZExtValue();
+ uint64_t c2 = N1C->getZExtValue();
+ EVT InnerShiftVT = N0.getOperand(0).getValueType();
+ EVT ShiftCountVT = N0.getOperand(0).getOperand(1).getValueType();
+ uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits();
+ // This is only valid if the OpSizeInBits + c1 = size of inner shift.
+ if (c1 + OpSizeInBits == InnerShiftSize) {
+ SDLoc DL(N0);
+ if (c1 + c2 >= InnerShiftSize)
+ return DAG.getConstant(0, DL, VT);
+ return DAG.getNode(ISD::TRUNCATE, DL, VT,
+ DAG.getNode(ISD::SRL, DL, InnerShiftVT,
+ N0.getOperand(0).getOperand(0),
+ DAG.getConstant(c1 + c2, DL,
+ ShiftCountVT)));
+ }
}
}
// fold (srl (shl x, c), c) -> (and x, cst2)
- if (N1C && N0.getOpcode() == ISD::SHL && N0.getOperand(1) == N1) {
- unsigned BitSize = N0.getScalarValueSizeInBits();
- if (BitSize <= 64) {
- uint64_t ShAmt = N1C->getZExtValue() + 64 - BitSize;
- SDLoc DL(N);
- return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0),
- DAG.getConstant(~0ULL >> ShAmt, DL, VT));
- }
+ if (N0.getOpcode() == ISD::SHL && N0.getOperand(1) == N1 &&
+ isConstantOrConstantVector(N1, /* NoOpaques */ true)) {
+ SDLoc DL(N);
+ SDValue Mask =
+ DAG.getNode(ISD::SRL, DL, VT, DAG.getAllOnesConstant(DL, VT), N1);
+ AddToWorklist(Mask.getNode());
+ return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), Mask);
}
// fold (srl (anyextend x), c) -> (and (anyextend (srl x, c)), mask)
@@ -4806,7 +6960,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
DAG.getConstant(ShiftAmt, DL0,
getShiftAmountTy(SmallVT)));
AddToWorklist(SmallShift.getNode());
- APInt Mask = APInt::getAllOnesValue(OpSizeInBits).lshr(ShiftAmt);
+ APInt Mask = APInt::getLowBitsSet(OpSizeInBits, OpSizeInBits - ShiftAmt);
SDLoc DL(N);
return DAG.getNode(ISD::AND, DL, VT,
DAG.getNode(ISD::ANY_EXTEND, DL, VT, SmallShift),
@@ -4824,20 +6978,19 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
// fold (srl (ctlz x), "5") -> x iff x has one bit set (the low bit).
if (N1C && N0.getOpcode() == ISD::CTLZ &&
N1C->getAPIntValue() == Log2_32(OpSizeInBits)) {
- APInt KnownZero, KnownOne;
- DAG.computeKnownBits(N0.getOperand(0), KnownZero, KnownOne);
+ KnownBits Known = DAG.computeKnownBits(N0.getOperand(0));
// If any of the input bits are KnownOne, then the input couldn't be all
// zeros, thus the result of the srl will always be zero.
- if (KnownOne.getBoolValue()) return DAG.getConstant(0, SDLoc(N0), VT);
+ if (Known.One.getBoolValue()) return DAG.getConstant(0, SDLoc(N0), VT);
// If all of the bits input the to ctlz node are known to be zero, then
// the result of the ctlz is "32" and the result of the shift is one.
- APInt UnknownBits = ~KnownZero;
+ APInt UnknownBits = ~Known.Zero;
if (UnknownBits == 0) return DAG.getConstant(1, SDLoc(N0), VT);
// Otherwise, check to see if there is exactly one bit input to the ctlz.
- if ((UnknownBits & (UnknownBits - 1)) == 0) {
+ if (UnknownBits.isPowerOf2()) {
// Okay, we know that only the single bit specified by UnknownBits
// could be set on input to the CTLZ node. If this bit is set, the SRL
// will return 0, if it is clear, it returns 1. Change the CTLZ/SRL pair
@@ -4911,12 +7064,63 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
return SDValue();
}
+SDValue DAGCombiner::visitFunnelShift(SDNode *N) {
+ EVT VT = N->getValueType(0);
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ SDValue N2 = N->getOperand(2);
+ bool IsFSHL = N->getOpcode() == ISD::FSHL;
+ unsigned BitWidth = VT.getScalarSizeInBits();
+
+ // fold (fshl N0, N1, 0) -> N0
+ // fold (fshr N0, N1, 0) -> N1
+ if (isPowerOf2_32(BitWidth))
+ if (DAG.MaskedValueIsZero(
+ N2, APInt(N2.getScalarValueSizeInBits(), BitWidth - 1)))
+ return IsFSHL ? N0 : N1;
+
+ // fold (fsh* N0, N1, c) -> (fsh* N0, N1, c % BitWidth)
+ if (ConstantSDNode *Cst = isConstOrConstSplat(N2)) {
+ if (Cst->getAPIntValue().uge(BitWidth)) {
+ uint64_t RotAmt = Cst->getAPIntValue().urem(BitWidth);
+ return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N0, N1,
+ DAG.getConstant(RotAmt, SDLoc(N), N2.getValueType()));
+ }
+ }
+
+ // fold (fshl N0, N0, N2) -> (rotl N0, N2)
+ // fold (fshr N0, N0, N2) -> (rotr N0, N2)
+ // TODO: Investigate flipping this rotate if only one is legal. If the funnel
+ // shift is legal as well, we might be better off avoiding the non-constant
+ // (BW - N2).
+ unsigned RotOpc = IsFSHL ? ISD::ROTL : ISD::ROTR;
+ if (N0 == N1 && hasOperation(RotOpc, VT))
+ return DAG.getNode(RotOpc, SDLoc(N), VT, N0, N2);
+
+ return SDValue();
+}
+
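For readers unfamiliar with funnel shifts: fshl conceptually concatenates the two operands, shifts left by the amount modulo the bit width, and keeps the high half (fshr keeps the low half of a right shift). A minimal C++ model of the i32 case and of the two folds above (illustrative only):

#include <cassert>
#include <cstdint>

uint32_t fshl32(uint32_t hi, uint32_t lo, unsigned c) {
  c &= 31; // the amount is taken modulo the bit width
  uint64_t wide = ((uint64_t)hi << 32) | lo;
  return (uint32_t)((wide << c) >> 32);
}

int main() {
  // fold (fshl N0, N1, 0) -> N0
  assert(fshl32(0x11111111u, 0x22222222u, 0) == 0x11111111u);
  // fold (fshl N0, N0, N2) -> (rotl N0, N2)
  uint32_t x = 0x80000001u;
  assert(fshl32(x, x, 1) == ((x << 1) | (x >> 31)));
}
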
+SDValue DAGCombiner::visitABS(SDNode *N) {
+ SDValue N0 = N->getOperand(0);
+ EVT VT = N->getValueType(0);
+
+ // fold (abs c1) -> c2
+ if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
+ return DAG.getNode(ISD::ABS, SDLoc(N), VT, N0);
+ // fold (abs (abs x)) -> (abs x)
+ if (N0.getOpcode() == ISD::ABS)
+ return N0;
+ // fold (abs x) -> x iff not-negative
+ if (DAG.SignBitIsZero(N0))
+ return N0;
+ return SDValue();
+}
+
SDValue DAGCombiner::visitBSWAP(SDNode *N) {
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);
// fold (bswap c1) -> c2
- if (isConstantIntBuildVectorOrConstantInt(N0))
+ if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
return DAG.getNode(ISD::BSWAP, SDLoc(N), VT, N0);
// fold (bswap (bswap x)) -> x
if (N0.getOpcode() == ISD::BSWAP)
@@ -4924,13 +7128,33 @@ SDValue DAGCombiner::visitBSWAP(SDNode *N) {
return SDValue();
}
+SDValue DAGCombiner::visitBITREVERSE(SDNode *N) {
+ SDValue N0 = N->getOperand(0);
+ EVT VT = N->getValueType(0);
+
+ // fold (bitreverse c1) -> c2
+ if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
+ return DAG.getNode(ISD::BITREVERSE, SDLoc(N), VT, N0);
+ // fold (bitreverse (bitreverse x)) -> x
+ if (N0.getOpcode() == ISD::BITREVERSE)
+ return N0.getOperand(0);
+ return SDValue();
+}
+
SDValue DAGCombiner::visitCTLZ(SDNode *N) {
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);
// fold (ctlz c1) -> c2
- if (isConstantIntBuildVectorOrConstantInt(N0))
+ if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
return DAG.getNode(ISD::CTLZ, SDLoc(N), VT, N0);
+
+ // If the value is known never to be zero, switch to the undef version.
+ if (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ_ZERO_UNDEF, VT)) {
+ if (DAG.isKnownNeverZero(N0))
+ return DAG.getNode(ISD::CTLZ_ZERO_UNDEF, SDLoc(N), VT, N0);
+ }
+
return SDValue();
}
@@ -4939,7 +7163,7 @@ SDValue DAGCombiner::visitCTLZ_ZERO_UNDEF(SDNode *N) {
EVT VT = N->getValueType(0);
// fold (ctlz_zero_undef c1) -> c2
- if (isConstantIntBuildVectorOrConstantInt(N0))
+ if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
return DAG.getNode(ISD::CTLZ_ZERO_UNDEF, SDLoc(N), VT, N0);
return SDValue();
}
@@ -4949,8 +7173,15 @@ SDValue DAGCombiner::visitCTTZ(SDNode *N) {
EVT VT = N->getValueType(0);
// fold (cttz c1) -> c2
- if (isConstantIntBuildVectorOrConstantInt(N0))
+ if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
return DAG.getNode(ISD::CTTZ, SDLoc(N), VT, N0);
+
+ // If the value is known never to be zero, switch to the undef version.
+ if (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ_ZERO_UNDEF, VT)) {
+ if (DAG.isKnownNeverZero(N0))
+ return DAG.getNode(ISD::CTTZ_ZERO_UNDEF, SDLoc(N), VT, N0);
+ }
+
return SDValue();
}
@@ -4959,7 +7190,7 @@ SDValue DAGCombiner::visitCTTZ_ZERO_UNDEF(SDNode *N) {
EVT VT = N->getValueType(0);
// fold (cttz_zero_undef c1) -> c2
- if (isConstantIntBuildVectorOrConstantInt(N0))
+ if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
return DAG.getNode(ISD::CTTZ_ZERO_UNDEF, SDLoc(N), VT, N0);
return SDValue();
}
@@ -4969,20 +7200,30 @@ SDValue DAGCombiner::visitCTPOP(SDNode *N) {
EVT VT = N->getValueType(0);
// fold (ctpop c1) -> c2
- if (isConstantIntBuildVectorOrConstantInt(N0))
+ if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
return DAG.getNode(ISD::CTPOP, SDLoc(N), VT, N0);
return SDValue();
}
+// FIXME: This should be checking for no signed zeros on individual operands, as
+// well as no nans.
+static bool isLegalToCombineMinNumMaxNum(SelectionDAG &DAG, SDValue LHS,
+                                         SDValue RHS) {
+ const TargetOptions &Options = DAG.getTarget().Options;
+ EVT VT = LHS.getValueType();
+
+ return Options.NoSignedZerosFPMath && VT.isFloatingPoint() &&
+ DAG.isKnownNeverNaN(LHS) && DAG.isKnownNeverNaN(RHS);
+}
-/// \brief Generate Min/Max node
-static SDValue combineMinNumMaxNum(SDLoc DL, EVT VT, SDValue LHS, SDValue RHS,
- SDValue True, SDValue False,
+/// Generate Min/Max node
+static SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
+ SDValue RHS, SDValue True, SDValue False,
ISD::CondCode CC, const TargetLowering &TLI,
SelectionDAG &DAG) {
if (!(LHS == True && RHS == False) && !(LHS == False && RHS == True))
return SDValue();
+ EVT TransformVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
switch (CC) {
case ISD::SETOLT:
case ISD::SETOLE:
@@ -4990,8 +7231,15 @@ static SDValue combineMinNumMaxNum(SDLoc DL, EVT VT, SDValue LHS, SDValue RHS,
case ISD::SETLE:
case ISD::SETULT:
case ISD::SETULE: {
+ // Since it's already known never to be NaN to get here, either fminnum or
+ // fminnum_ieee is OK. Try the IEEE version first, since fminnum is
+ // expanded in terms of it.
+ unsigned IEEEOpcode = (LHS == True) ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
+ if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
+ return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);
+
unsigned Opcode = (LHS == True) ? ISD::FMINNUM : ISD::FMAXNUM;
- if (TLI.isOperationLegal(Opcode, VT))
+ if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
return DAG.getNode(Opcode, DL, VT, LHS, RHS);
return SDValue();
}
@@ -5001,8 +7249,12 @@ static SDValue combineMinNumMaxNum(SDLoc DL, EVT VT, SDValue LHS, SDValue RHS,
case ISD::SETGE:
case ISD::SETUGT:
case ISD::SETUGE: {
+ unsigned IEEEOpcode = (LHS == True) ? ISD::FMAXNUM_IEEE : ISD::FMINNUM_IEEE;
+ if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
+ return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);
+
unsigned Opcode = (LHS == True) ? ISD::FMAXNUM : ISD::FMINNUM;
- if (TLI.isOperationLegal(Opcode, VT))
+ if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
return DAG.getNode(Opcode, DL, VT, LHS, RHS);
return SDValue();
}
@@ -5011,25 +7263,76 @@ static SDValue combineMinNumMaxNum(SDLoc DL, EVT VT, SDValue LHS, SDValue RHS,
}
}
-SDValue DAGCombiner::visitSELECT(SDNode *N) {
- SDValue N0 = N->getOperand(0);
+SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) {
+ SDValue Cond = N->getOperand(0);
SDValue N1 = N->getOperand(1);
SDValue N2 = N->getOperand(2);
EVT VT = N->getValueType(0);
- EVT VT0 = N0.getValueType();
+ EVT CondVT = Cond.getValueType();
+ SDLoc DL(N);
- // fold (select C, X, X) -> X
- if (N1 == N2)
- return N1;
- if (const ConstantSDNode *N0C = dyn_cast<const ConstantSDNode>(N0)) {
- // fold (select true, X, Y) -> X
- // fold (select false, X, Y) -> Y
- return !N0C->isNullValue() ? N1 : N2;
- }
- // fold (select C, 1, X) -> (or C, X)
- if (VT == MVT::i1 && isOneConstant(N1))
- return DAG.getNode(ISD::OR, SDLoc(N), VT, N0, N2);
- // fold (select C, 0, 1) -> (xor C, 1)
+ if (!VT.isInteger())
+ return SDValue();
+
+ auto *C1 = dyn_cast<ConstantSDNode>(N1);
+ auto *C2 = dyn_cast<ConstantSDNode>(N2);
+ if (!C1 || !C2)
+ return SDValue();
+
+ // Only do this before legalization to avoid conflicting with target-specific
+ // transforms in the other direction (create a select from a zext/sext). There
+ // is also a target-independent combine here in DAGCombiner in the other
+ // direction for (select Cond, -1, 0) when the condition is not i1.
+ if (CondVT == MVT::i1 && !LegalOperations) {
+ if (C1->isNullValue() && C2->isOne()) {
+ // select Cond, 0, 1 --> zext (!Cond)
+ SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
+ if (VT != MVT::i1)
+ NotCond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, NotCond);
+ return NotCond;
+ }
+ if (C1->isNullValue() && C2->isAllOnesValue()) {
+ // select Cond, 0, -1 --> sext (!Cond)
+ SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
+ if (VT != MVT::i1)
+ NotCond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, NotCond);
+ return NotCond;
+ }
+ if (C1->isOne() && C2->isNullValue()) {
+ // select Cond, 1, 0 --> zext (Cond)
+ if (VT != MVT::i1)
+ Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cond);
+ return Cond;
+ }
+ if (C1->isAllOnesValue() && C2->isNullValue()) {
+ // select Cond, -1, 0 --> sext (Cond)
+ if (VT != MVT::i1)
+ Cond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Cond);
+ return Cond;
+ }
+
+ // For any constants that differ by 1, we can transform the select into an
+ // extend and add. Use a target hook because some targets may prefer to
+ // transform in the other direction.
+ if (TLI.convertSelectOfConstantsToMath(VT)) {
+ if (C1->getAPIntValue() - 1 == C2->getAPIntValue()) {
+ // select Cond, C1, C1-1 --> add (zext Cond), C1-1
+ if (VT != MVT::i1)
+ Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cond);
+ return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
+ }
+ if (C1->getAPIntValue() + 1 == C2->getAPIntValue()) {
+ // select Cond, C1, C1+1 --> add (sext Cond), C1+1
+ if (VT != MVT::i1)
+ Cond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Cond);
+ return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
+ }
+ }
+
+ return SDValue();
+ }
+
+ // fold (select Cond, 0, 1) -> (xor Cond, 1)
// We can't do this reliably if integer based booleans have different contents
// to floating point based booleans. This is because we can't tell whether we
// have an integer-based boolean or a floating-point-based boolean unless we
@@ -5038,85 +7341,92 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) {
// undiscoverable (or not reasonably discoverable). For example, it could be
// in another basic block or it could require searching a complicated
// expression.
- if (VT.isInteger() &&
- (VT0 == MVT::i1 || (VT0.isInteger() &&
- TLI.getBooleanContents(false, false) ==
- TLI.getBooleanContents(false, true) &&
- TLI.getBooleanContents(false, false) ==
- TargetLowering::ZeroOrOneBooleanContent)) &&
- isNullConstant(N1) && isOneConstant(N2)) {
- SDValue XORNode;
- if (VT == VT0) {
- SDLoc DL(N);
- return DAG.getNode(ISD::XOR, DL, VT0,
- N0, DAG.getConstant(1, DL, VT0));
- }
- SDLoc DL0(N0);
- XORNode = DAG.getNode(ISD::XOR, DL0, VT0,
- N0, DAG.getConstant(1, DL0, VT0));
- AddToWorklist(XORNode.getNode());
- if (VT.bitsGT(VT0))
- return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, XORNode);
- return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, XORNode);
+ if (CondVT.isInteger() &&
+ TLI.getBooleanContents(/*isVec*/false, /*isFloat*/true) ==
+ TargetLowering::ZeroOrOneBooleanContent &&
+ TLI.getBooleanContents(/*isVec*/false, /*isFloat*/false) ==
+ TargetLowering::ZeroOrOneBooleanContent &&
+ C1->isNullValue() && C2->isOne()) {
+ SDValue NotCond =
+ DAG.getNode(ISD::XOR, DL, CondVT, Cond, DAG.getConstant(1, DL, CondVT));
+ if (VT.bitsEq(CondVT))
+ return NotCond;
+ return DAG.getZExtOrTrunc(NotCond, DL, VT);
}
+
+ return SDValue();
+}
+
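A compact C++ sketch (illustrative only) of the differ-by-one transform above: zext of an i1 is 0 or 1 and sext is 0 or -1, so a select between C1 and C1-1 (or C1 and C1+1) becomes an extend plus an add.

#include <cassert>
#include <cstdint>

// select Cond, C1, C1-1 --> add (zext Cond), C1-1
int32_t selZext(bool cond, int32_t c1) { return (int32_t)cond + (c1 - 1); }
// select Cond, C1, C1+1 --> add (sext Cond), C1+1
int32_t selSext(bool cond, int32_t c1) { return -(int32_t)cond + (c1 + 1); }

int main() {
  assert(selZext(true, 42) == 42 && selZext(false, 42) == 41);
  assert(selSext(true, 42) == 42 && selSext(false, 42) == 43);
}
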
+SDValue DAGCombiner::visitSELECT(SDNode *N) {
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ SDValue N2 = N->getOperand(2);
+ EVT VT = N->getValueType(0);
+ EVT VT0 = N0.getValueType();
+ SDLoc DL(N);
+
+ if (SDValue V = DAG.simplifySelect(N0, N1, N2))
+ return V;
+
+ // fold (select X, X, Y) -> (or X, Y)
+ // fold (select X, 1, Y) -> (or X, Y)
+ if (VT == VT0 && VT == MVT::i1 && (N0 == N1 || isOneConstant(N1)))
+ return DAG.getNode(ISD::OR, DL, VT, N0, N2);
+
+ if (SDValue V = foldSelectOfConstants(N))
+ return V;
+
// fold (select C, 0, X) -> (and (not C), X)
if (VT == VT0 && VT == MVT::i1 && isNullConstant(N1)) {
SDValue NOTNode = DAG.getNOT(SDLoc(N0), N0, VT);
AddToWorklist(NOTNode.getNode());
- return DAG.getNode(ISD::AND, SDLoc(N), VT, NOTNode, N2);
+ return DAG.getNode(ISD::AND, DL, VT, NOTNode, N2);
}
// fold (select C, X, 1) -> (or (not C), X)
if (VT == VT0 && VT == MVT::i1 && isOneConstant(N2)) {
SDValue NOTNode = DAG.getNOT(SDLoc(N0), N0, VT);
AddToWorklist(NOTNode.getNode());
- return DAG.getNode(ISD::OR, SDLoc(N), VT, NOTNode, N1);
+ return DAG.getNode(ISD::OR, DL, VT, NOTNode, N1);
}
- // fold (select C, X, 0) -> (and C, X)
- if (VT == MVT::i1 && isNullConstant(N2))
- return DAG.getNode(ISD::AND, SDLoc(N), VT, N0, N1);
- // fold (select X, X, Y) -> (or X, Y)
- // fold (select X, 1, Y) -> (or X, Y)
- if (VT == MVT::i1 && (N0 == N1 || isOneConstant(N1)))
- return DAG.getNode(ISD::OR, SDLoc(N), VT, N0, N2);
// fold (select X, Y, X) -> (and X, Y)
// fold (select X, Y, 0) -> (and X, Y)
- if (VT == MVT::i1 && (N0 == N2 || isNullConstant(N2)))
- return DAG.getNode(ISD::AND, SDLoc(N), VT, N0, N1);
+ if (VT == VT0 && VT == MVT::i1 && (N0 == N2 || isNullConstant(N2)))
+ return DAG.getNode(ISD::AND, DL, VT, N0, N1);
// If we can fold this based on the true/false value, do so.
if (SimplifySelectOps(N, N1, N2))
- return SDValue(N, 0); // Don't revisit N.
+ return SDValue(N, 0); // Don't revisit N.
if (VT0 == MVT::i1) {
// The code in this block deals with the following 2 equivalences:
// select(C0|C1, x, y) <=> select(C0, x, select(C1, x, y))
// select(C0&C1, x, y) <=> select(C0, select(C1, x, y), y)
- // The target can specify its prefered form with the
+ // The target can specify its preferred form with the
// shouldNormalizeToSelectSequence() callback. However, we always transform
// to the right if we find that the inner select already exists in the DAG,
// and we always transform to the left side if we know that we can further
// optimize the combination of the conditions.
- bool normalizeToSequence
- = TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT);
+ bool normalizeToSequence =
+ TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT);
// select (and Cond0, Cond1), X, Y
// -> select Cond0, (select Cond1, X, Y), Y
if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) {
SDValue Cond0 = N0->getOperand(0);
SDValue Cond1 = N0->getOperand(1);
- SDValue InnerSelect = DAG.getNode(ISD::SELECT, SDLoc(N),
- N1.getValueType(), Cond1, N1, N2);
+ SDValue InnerSelect =
+ DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond1, N1, N2);
if (normalizeToSequence || !InnerSelect.use_empty())
- return DAG.getNode(ISD::SELECT, SDLoc(N), N1.getValueType(), Cond0,
+ return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0,
InnerSelect, N2);
}
// select (or Cond0, Cond1), X, Y -> select Cond0, X, (select Cond1, X, Y)
if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) {
SDValue Cond0 = N0->getOperand(0);
SDValue Cond1 = N0->getOperand(1);
- SDValue InnerSelect = DAG.getNode(ISD::SELECT, SDLoc(N),
- N1.getValueType(), Cond1, N1, N2);
+ SDValue InnerSelect =
+ DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond1, N1, N2);
if (normalizeToSequence || !InnerSelect.use_empty())
- return DAG.getNode(ISD::SELECT, SDLoc(N), N1.getValueType(), Cond0, N1,
+ return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0, N1,
InnerSelect);
}
@@ -5128,15 +7438,13 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) {
if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) {
// Create the actual and node if we can generate good code for it.
if (!normalizeToSequence) {
- SDValue And = DAG.getNode(ISD::AND, SDLoc(N), N0.getValueType(),
- N0, N1_0);
- return DAG.getNode(ISD::SELECT, SDLoc(N), N1.getValueType(), And,
- N1_1, N2);
+ SDValue And = DAG.getNode(ISD::AND, DL, N0.getValueType(), N0, N1_0);
+ return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), And, N1_1, N2);
}
// Otherwise see if we can optimize the "and" to a better pattern.
if (SDValue Combined = visitANDLike(N0, N1_0, N))
- return DAG.getNode(ISD::SELECT, SDLoc(N), N1.getValueType(), Combined,
- N1_1, N2);
+ return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1_1,
+ N2);
}
}
// select Cond0, X, (select Cond1, X, Y) -> select (or Cond0, Cond1), X, Y
@@ -5147,49 +7455,72 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) {
if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) {
// Create the actual or node if we can generate good code for it.
if (!normalizeToSequence) {
- SDValue Or = DAG.getNode(ISD::OR, SDLoc(N), N0.getValueType(),
- N0, N2_0);
- return DAG.getNode(ISD::SELECT, SDLoc(N), N1.getValueType(), Or,
- N1, N2_2);
+ SDValue Or = DAG.getNode(ISD::OR, DL, N0.getValueType(), N0, N2_0);
+ return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Or, N1, N2_2);
}
// Otherwise see if we can optimize to a better pattern.
if (SDValue Combined = visitORLike(N0, N2_0, N))
- return DAG.getNode(ISD::SELECT, SDLoc(N), N1.getValueType(), Combined,
- N1, N2_2);
+ return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1,
+ N2_2);
}
}
}
- // fold selects based on a setcc into other things, such as min/max/abs
- if (N0.getOpcode() == ISD::SETCC) {
- // select x, y (fcmp lt x, y) -> fminnum x, y
- // select x, y (fcmp gt x, y) -> fmaxnum x, y
- //
- // This is OK if we don't care about what happens if either operand is a
- // NaN.
- //
+ if (VT0 == MVT::i1) {
+ // select (not Cond), N1, N2 -> select Cond, N2, N1
+ if (isBitwiseNot(N0))
+ return DAG.getNode(ISD::SELECT, DL, VT, N0->getOperand(0), N2, N1);
+ }
- // FIXME: Instead of testing for UnsafeFPMath, this should be checking for
- // no signed zeros as well as no nans.
- const TargetOptions &Options = DAG.getTarget().Options;
- if (Options.UnsafeFPMath &&
- VT.isFloatingPoint() && N0.hasOneUse() &&
- DAG.isKnownNeverNaN(N1) && DAG.isKnownNeverNaN(N2)) {
- ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
+ // Fold selects based on a setcc into other things, such as min/max/abs.
+ if (N0.getOpcode() == ISD::SETCC) {
+ SDValue Cond0 = N0.getOperand(0), Cond1 = N0.getOperand(1);
+ ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
- if (SDValue FMinMax = combineMinNumMaxNum(SDLoc(N), VT, N0.getOperand(0),
- N0.getOperand(1), N1, N2, CC,
- TLI, DAG))
+ // select (fcmp lt x, y), x, y -> fminnum x, y
+ // select (fcmp gt x, y), x, y -> fmaxnum x, y
+ //
+ // This is OK if we don't care what happens if either operand is a NaN.
+ if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, N1, N2))
+ if (SDValue FMinMax = combineMinNumMaxNum(DL, VT, Cond0, Cond1, N1, N2,
+ CC, TLI, DAG))
return FMinMax;
+
+ // Use 'unsigned add with overflow' to optimize an unsigned saturating add.
+ // This is conservatively limited to pre-legal-operations to give targets
+ // a chance to reverse the transform if they want to do that. Also, it is
+ // unlikely that the pattern would be formed late, so it's probably not
+ // worth going through the other checks.
+ if (!LegalOperations && TLI.isOperationLegalOrCustom(ISD::UADDO, VT) &&
+ CC == ISD::SETUGT && N0.hasOneUse() && isAllOnesConstant(N1) &&
+ N2.getOpcode() == ISD::ADD && Cond0 == N2.getOperand(0)) {
+ auto *C = dyn_cast<ConstantSDNode>(N2.getOperand(1));
+ auto *NotC = dyn_cast<ConstantSDNode>(Cond1);
+ if (C && NotC && C->getAPIntValue() == ~NotC->getAPIntValue()) {
+ // select (setcc Cond0, ~C, ugt), -1, (add Cond0, C) -->
+ // uaddo Cond0, C; select uaddo.1, -1, uaddo.0
+ //
+ // The IR equivalent of this transform would have this form:
+ // %a = add %x, C
+ // %c = icmp ugt %x, ~C
+ // %r = select %c, -1, %a
+ // =>
+ // %u = call {iN,i1} llvm.uadd.with.overflow(%x, C)
+ // %u0 = extractvalue %u, 0
+ // %u1 = extractvalue %u, 1
+ // %r = select %u1, -1, %u0
+ SDVTList VTs = DAG.getVTList(VT, VT0);
+ SDValue UAO = DAG.getNode(ISD::UADDO, DL, VTs, Cond0, N2.getOperand(1));
+ return DAG.getSelect(DL, VT, UAO.getValue(1), N1, UAO.getValue(0));
+ }
}
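The IR shown in the comment is the standard branch-free unsigned saturating add; a C++ model of the pattern being matched (illustrative only — in the actual fold, C must be a constant and ~C the comparison operand):

#include <cassert>
#include <cstdint>

// %a = add %x, C; %c = icmp ugt %x, ~C; %r = select %c, -1, %a
uint32_t satAdd(uint32_t x, uint32_t c) {
  uint32_t a = x + c;
  return x > ~c ? 0xffffffffu : a; // x > ~c iff x + c wraps
}

int main() {
  assert(satAdd(10, 20) == 30);
  assert(satAdd(0xfffffff0u, 0x20u) == 0xffffffffu); // clamped
}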
- if ((!LegalOperations &&
- TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT)) ||
- TLI.isOperationLegal(ISD::SELECT_CC, VT))
- return DAG.getNode(ISD::SELECT_CC, SDLoc(N), VT,
- N0.getOperand(0), N0.getOperand(1),
- N1, N2, N0.getOperand(2));
- return SimplifySelect(SDLoc(N), N0, N1, N2);
+ if (TLI.isOperationLegal(ISD::SELECT_CC, VT) ||
+ (!LegalOperations && TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT)))
+ return DAG.getNode(ISD::SELECT_CC, DL, VT, Cond0, Cond1, N1, N2,
+ N0.getOperand(2));
+
+ return SimplifySelect(DL, N0, N1, N2);
}
return SDValue();
@@ -5215,7 +7546,7 @@ std::pair<SDValue, SDValue> SplitVSETCC(const SDNode *N, SelectionDAG &DAG) {
// This function assumes all the vselect's arguments are CONCAT_VECTOR
// nodes and that the condition is a BV of ConstantSDNodes (or undefs).
static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
- SDLoc dl(N);
+ SDLoc DL(N);
SDValue Cond = N->getOperand(0);
SDValue LHS = N->getOperand(1);
SDValue RHS = N->getOperand(2);
@@ -5237,7 +7568,7 @@ static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
// length of the BV and see if all the non-undef nodes are the same.
ConstantSDNode *BottomHalf = nullptr;
for (int i = 0; i < NumElems / 2; ++i) {
- if (Cond->getOperand(i)->getOpcode() == ISD::UNDEF)
+ if (Cond->getOperand(i)->isUndef())
continue;
if (BottomHalf == nullptr)
@@ -5249,7 +7580,7 @@ static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
// Do the same for the second half of the BuildVector
ConstantSDNode *TopHalf = nullptr;
for (int i = NumElems / 2; i < NumElems; ++i) {
- if (Cond->getOperand(i)->getOpcode() == ISD::UNDEF)
+ if (Cond->getOperand(i)->isUndef())
continue;
if (TopHalf == nullptr)
@@ -5262,13 +7593,12 @@ static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
"One half of the selector was all UNDEFs and the other was all the "
"same value. This should have been addressed before this function.");
return DAG.getNode(
- ISD::CONCAT_VECTORS, dl, VT,
+ ISD::CONCAT_VECTORS, DL, VT,
BottomHalf->isNullValue() ? RHS->getOperand(0) : LHS->getOperand(0),
TopHalf->isNullValue() ? RHS->getOperand(1) : LHS->getOperand(1));
}
SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
-
if (Level >= AfterLegalizeTypes)
return SDValue();
@@ -5288,7 +7618,7 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
if (TLI.getTypeAction(*DAG.getContext(), Data.getValueType()) !=
TargetLowering::TypeSplitVector)
return SDValue();
- SDValue MaskLo, MaskHi, Lo, Hi;
+ SDValue MaskLo, MaskHi;
std::tie(MaskLo, MaskHi) = SplitVSETCC(Mask.getNode(), DAG);
EVT LoVT, HiVT;
@@ -5305,6 +7635,7 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
SDValue DataLo, DataHi;
std::tie(DataLo, DataHi) = DAG.SplitVector(Data, DL);
+ SDValue Scale = MSC->getScale();
SDValue BasePtr = MSC->getBasePtr();
SDValue IndexLo, IndexHi;
std::tie(IndexLo, IndexHi) = DAG.SplitVector(MSC->getIndex(), DL);
@@ -5314,28 +7645,26 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
MachineMemOperand::MOStore, LoMemVT.getStoreSize(),
Alignment, MSC->getAAInfo(), MSC->getRanges());
- SDValue OpsLo[] = { Chain, DataLo, MaskLo, BasePtr, IndexLo };
- Lo = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataLo.getValueType(),
- DL, OpsLo, MMO);
-
- SDValue OpsHi[] = {Chain, DataHi, MaskHi, BasePtr, IndexHi};
- Hi = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataHi.getValueType(),
- DL, OpsHi, MMO);
-
- AddToWorklist(Lo.getNode());
- AddToWorklist(Hi.getNode());
+ SDValue OpsLo[] = { Chain, DataLo, MaskLo, BasePtr, IndexLo, Scale };
+ SDValue Lo = DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
+ DataLo.getValueType(), DL, OpsLo, MMO);
- return DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Lo, Hi);
+ // The order of the scatter operations after the split is well defined: the
+ // "Hi" part comes after the "Lo", so the two operations should be chained
+ // one after another.
+ SDValue OpsHi[] = { Lo, DataHi, MaskHi, BasePtr, IndexHi, Scale };
+ return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataHi.getValueType(),
+ DL, OpsHi, MMO);
}
SDValue DAGCombiner::visitMSTORE(SDNode *N) {
-
if (Level >= AfterLegalizeTypes)
return SDValue();
MaskedStoreSDNode *MST = dyn_cast<MaskedStoreSDNode>(N);
SDValue Mask = MST->getMask();
SDValue Data = MST->getValue();
+ EVT VT = Data.getValueType();
SDLoc DL(N);
// If the MSTORE data type requires splitting and the mask is provided by a
@@ -5343,18 +7672,14 @@ SDValue DAGCombiner::visitMSTORE(SDNode *N) {
// prevents the type legalizer from unrolling SETCC into scalar comparisons
// and enables future optimizations (e.g. min/max pattern matching on X86).
if (Mask.getOpcode() == ISD::SETCC) {
-
// Check if any splitting is required.
- if (TLI.getTypeAction(*DAG.getContext(), Data.getValueType()) !=
+ if (TLI.getTypeAction(*DAG.getContext(), VT) !=
TargetLowering::TypeSplitVector)
return SDValue();
SDValue MaskLo, MaskHi, Lo, Hi;
std::tie(MaskLo, MaskHi) = SplitVSETCC(Mask.getNode(), DAG);
- EVT LoVT, HiVT;
- std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(MST->getValueType(0));
-
SDValue Chain = MST->getChain();
SDValue Ptr = MST->getBasePtr();
@@ -5364,8 +7689,7 @@ SDValue DAGCombiner::visitMSTORE(SDNode *N) {
// if Alignment is equal to the vector size,
// take the half of it for the second part
unsigned SecondHalfAlignment =
- (Alignment == Data->getValueType(0).getSizeInBits()/8) ?
- Alignment/2 : Alignment;
+ (Alignment == VT.getSizeInBits() / 8) ? Alignment / 2 : Alignment;
EVT LoMemVT, HiMemVT;
std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT);
@@ -5379,20 +7703,21 @@ SDValue DAGCombiner::visitMSTORE(SDNode *N) {
Alignment, MST->getAAInfo(), MST->getRanges());
Lo = DAG.getMaskedStore(Chain, DL, DataLo, Ptr, MaskLo, LoMemVT, MMO,
- MST->isTruncatingStore());
+ MST->isTruncatingStore(),
+ MST->isCompressingStore());
- unsigned IncrementSize = LoMemVT.getSizeInBits()/8;
- Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr,
- DAG.getConstant(IncrementSize, DL, Ptr.getValueType()));
+ Ptr = TLI.IncrementMemoryAddress(Ptr, MaskLo, DL, LoMemVT, DAG,
+ MST->isCompressingStore());
+ unsigned HiOffset = LoMemVT.getStoreSize();
- MMO = DAG.getMachineFunction().
- getMachineMemOperand(MST->getPointerInfo(),
- MachineMemOperand::MOStore, HiMemVT.getStoreSize(),
- SecondHalfAlignment, MST->getAAInfo(),
- MST->getRanges());
+ MMO = DAG.getMachineFunction().getMachineMemOperand(
+ MST->getPointerInfo().getWithOffset(HiOffset),
+ MachineMemOperand::MOStore, HiMemVT.getStoreSize(), SecondHalfAlignment,
+ MST->getAAInfo(), MST->getRanges());
Hi = DAG.getMaskedStore(Chain, DL, DataHi, Ptr, MaskHi, HiMemVT, MMO,
- MST->isTruncatingStore());
+ MST->isTruncatingStore(),
+ MST->isCompressingStore());
AddToWorklist(Lo.getNode());
AddToWorklist(Hi.getNode());
@@ -5403,11 +7728,10 @@ SDValue DAGCombiner::visitMSTORE(SDNode *N) {
}
SDValue DAGCombiner::visitMGATHER(SDNode *N) {
-
if (Level >= AfterLegalizeTypes)
return SDValue();
- MaskedGatherSDNode *MGT = dyn_cast<MaskedGatherSDNode>(N);
+ MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(N);
SDValue Mask = MGT->getMask();
SDLoc DL(N);
@@ -5429,9 +7753,9 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) {
SDValue MaskLo, MaskHi, Lo, Hi;
std::tie(MaskLo, MaskHi) = SplitVSETCC(Mask.getNode(), DAG);
- SDValue Src0 = MGT->getValue();
- SDValue Src0Lo, Src0Hi;
- std::tie(Src0Lo, Src0Hi) = DAG.SplitVector(Src0, DL);
+ SDValue PassThru = MGT->getPassThru();
+ SDValue PassThruLo, PassThruHi;
+ std::tie(PassThruLo, PassThruHi) = DAG.SplitVector(PassThru, DL);
EVT LoVT, HiVT;
std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
@@ -5443,6 +7767,7 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) {
EVT LoMemVT, HiMemVT;
std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT);
+ SDValue Scale = MGT->getScale();
SDValue BasePtr = MGT->getBasePtr();
SDValue Index = MGT->getIndex();
SDValue IndexLo, IndexHi;
@@ -5453,13 +7778,13 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) {
MachineMemOperand::MOLoad, LoMemVT.getStoreSize(),
Alignment, MGT->getAAInfo(), MGT->getRanges());
- SDValue OpsLo[] = { Chain, Src0Lo, MaskLo, BasePtr, IndexLo };
+ SDValue OpsLo[] = { Chain, PassThruLo, MaskLo, BasePtr, IndexLo, Scale };
Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, DL, OpsLo,
- MMO);
+ MMO);
- SDValue OpsHi[] = {Chain, Src0Hi, MaskHi, BasePtr, IndexHi};
+ SDValue OpsHi[] = { Chain, PassThruHi, MaskHi, BasePtr, IndexHi, Scale };
Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, DL, OpsHi,
- MMO);
+ MMO);
AddToWorklist(Lo.getNode());
AddToWorklist(Hi.getNode());
@@ -5480,7 +7805,6 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) {
}
SDValue DAGCombiner::visitMLOAD(SDNode *N) {
-
if (Level >= AfterLegalizeTypes)
return SDValue();
@@ -5492,7 +7816,6 @@ SDValue DAGCombiner::visitMLOAD(SDNode *N) {
// SETCC, then split both nodes and its operands before legalization. This
// prevents the type legalizer from unrolling SETCC into scalar comparisons
// and enables future optimizations (e.g. min/max pattern matching on X86).
-
if (Mask.getOpcode() == ISD::SETCC) {
EVT VT = N->getValueType(0);
@@ -5504,9 +7827,9 @@ SDValue DAGCombiner::visitMLOAD(SDNode *N) {
SDValue MaskLo, MaskHi, Lo, Hi;
std::tie(MaskLo, MaskHi) = SplitVSETCC(Mask.getNode(), DAG);
- SDValue Src0 = MLD->getSrc0();
- SDValue Src0Lo, Src0Hi;
- std::tie(Src0Lo, Src0Hi) = DAG.SplitVector(Src0, DL);
+ SDValue PassThru = MLD->getPassThru();
+ SDValue PassThruLo, PassThruHi;
+ std::tie(PassThruLo, PassThruHi) = DAG.SplitVector(PassThru, DL);
EVT LoVT, HiVT;
std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(MLD->getValueType(0));
@@ -5530,20 +7853,20 @@ SDValue DAGCombiner::visitMLOAD(SDNode *N) {
MachineMemOperand::MOLoad, LoMemVT.getStoreSize(),
Alignment, MLD->getAAInfo(), MLD->getRanges());
- Lo = DAG.getMaskedLoad(LoVT, DL, Chain, Ptr, MaskLo, Src0Lo, LoMemVT, MMO,
- ISD::NON_EXTLOAD);
+ Lo = DAG.getMaskedLoad(LoVT, DL, Chain, Ptr, MaskLo, PassThruLo, LoMemVT,
+ MMO, ISD::NON_EXTLOAD, MLD->isExpandingLoad());
- unsigned IncrementSize = LoMemVT.getSizeInBits()/8;
- Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr,
- DAG.getConstant(IncrementSize, DL, Ptr.getValueType()));
+ Ptr = TLI.IncrementMemoryAddress(Ptr, MaskLo, DL, LoMemVT, DAG,
+ MLD->isExpandingLoad());
+ unsigned HiOffset = LoMemVT.getStoreSize();
- MMO = DAG.getMachineFunction().
- getMachineMemOperand(MLD->getPointerInfo(),
- MachineMemOperand::MOLoad, HiMemVT.getStoreSize(),
- SecondHalfAlignment, MLD->getAAInfo(), MLD->getRanges());
+ MMO = DAG.getMachineFunction().getMachineMemOperand(
+ MLD->getPointerInfo().getWithOffset(HiOffset),
+ MachineMemOperand::MOLoad, HiMemVT.getStoreSize(), SecondHalfAlignment,
+ MLD->getAAInfo(), MLD->getRanges());
- Hi = DAG.getMaskedLoad(HiVT, DL, Chain, Ptr, MaskHi, Src0Hi, HiMemVT, MMO,
- ISD::NON_EXTLOAD);
+ Hi = DAG.getMaskedLoad(HiVT, DL, Chain, Ptr, MaskHi, PassThruHi, HiMemVT,
+ MMO, ISD::NON_EXTLOAD, MLD->isExpandingLoad());
AddToWorklist(Lo.getNode());
AddToWorklist(Hi.getNode());
@@ -5565,12 +7888,66 @@ SDValue DAGCombiner::visitMLOAD(SDNode *N) {
return SDValue();
}
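
(Editorial sketch, not part of the patch: the masked load/gather splitting above is sound because masked memory ops are lane-independent; each result lane is either the loaded/gathered element or the pass-through value, so computing the lo and hi halves separately reproduces the unsplit result. A standalone check in plain C++ with hypothetical values:)

#include <cassert>

int main() {
  // Emulate a 4-lane masked gather, then the same gather split into lo/hi
  // halves; lanes with a false mask keep the pass-through value.
  int Table[8] = {10, 11, 12, 13, 14, 15, 16, 17};
  int Index[4] = {3, 0, 7, 2};
  bool Mask[4] = {true, false, true, true};
  int PassThru[4] = {-1, -1, -1, -1};
  int Whole[4], Split[4];
  for (int i = 0; i < 4; ++i)
    Whole[i] = Mask[i] ? Table[Index[i]] : PassThru[i];
  for (int i = 0; i < 2; ++i)        // lo half: MaskLo/IndexLo/PassThruLo
    Split[i] = Mask[i] ? Table[Index[i]] : PassThru[i];
  for (int i = 2; i < 4; ++i)        // hi half: MaskHi/IndexHi/PassThruHi
    Split[i] = Mask[i] ? Table[Index[i]] : PassThru[i];
  for (int i = 0; i < 4; ++i)
    assert(Whole[i] == Split[i]);
  return 0;
}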
+/// A vector select of 2 constant vectors can be simplified to math/logic to
+/// avoid a variable select instruction and possibly avoid constant loads.
+SDValue DAGCombiner::foldVSelectOfConstants(SDNode *N) {
+ SDValue Cond = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ SDValue N2 = N->getOperand(2);
+ EVT VT = N->getValueType(0);
+ if (!Cond.hasOneUse() || Cond.getScalarValueSizeInBits() != 1 ||
+ !TLI.convertSelectOfConstantsToMath(VT) ||
+ !ISD::isBuildVectorOfConstantSDNodes(N1.getNode()) ||
+ !ISD::isBuildVectorOfConstantSDNodes(N2.getNode()))
+ return SDValue();
+
+ // Check if we can use the condition value to increment/decrement a single
+ // constant value. This simplifies a select to an add and removes a constant
+ // load/materialization from the general case.
+ bool AllAddOne = true;
+ bool AllSubOne = true;
+ unsigned Elts = VT.getVectorNumElements();
+ for (unsigned i = 0; i != Elts; ++i) {
+ SDValue N1Elt = N1.getOperand(i);
+ SDValue N2Elt = N2.getOperand(i);
+ if (N1Elt.isUndef() || N2Elt.isUndef())
+ continue;
+
+ const APInt &C1 = cast<ConstantSDNode>(N1Elt)->getAPIntValue();
+ const APInt &C2 = cast<ConstantSDNode>(N2Elt)->getAPIntValue();
+ if (C1 != C2 + 1)
+ AllAddOne = false;
+ if (C1 != C2 - 1)
+ AllSubOne = false;
+ }
+
+ // Further simplifications for the extra-special cases where the constants are
+ // all 0 or all -1 should be implemented as folds of these patterns.
+ SDLoc DL(N);
+ if (AllAddOne || AllSubOne) {
+ // vselect <N x i1> Cond, C+1, C --> add (zext Cond), C
+ // vselect <N x i1> Cond, C-1, C --> add (sext Cond), C
+ auto ExtendOpcode = AllAddOne ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND;
+ SDValue ExtendedCond = DAG.getNode(ExtendOpcode, DL, VT, Cond);
+ return DAG.getNode(ISD::ADD, DL, VT, ExtendedCond, N2);
+ }
+
+ // The general case for select-of-constants:
+ // vselect <N x i1> Cond, C1, C2 --> xor (and (sext Cond), (C1^C2)), C2
+ // ...but that only makes sense if a vselect is slower than 2 logic ops, so
+ // leave that to a machine-specific pass.
+ return SDValue();
+}
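
(Editorial sketch, not part of the patch: foldVSelectOfConstants rests on a per-lane identity. A select between C+1 and C is an add of the zero-extended i1 condition, and a select between C-1 and C is an add of the sign-extended condition, since i1 true sign-extends to -1. Scalar check:)

#include <cassert>

int main() {
  int C = 17;
  for (bool Cond : {false, true}) {
    assert((Cond ? C + 1 : C) == C + (int)Cond);        // add (zext Cond), C
    assert((Cond ? C - 1 : C) == C + (Cond ? -1 : 0));  // add (sext Cond), C
  }
  return 0;
}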
+
SDValue DAGCombiner::visitVSELECT(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
SDValue N2 = N->getOperand(2);
SDLoc DL(N);
+ if (SDValue V = DAG.simplifySelect(N0, N1, N2))
+ return V;
+
// Canonicalize integer abs.
// vselect (setg[te] X, 0), X, -X ->
// vselect (setgt X, -1), X, -X ->
@@ -5592,47 +7969,66 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
if (isAbs) {
EVT VT = LHS.getValueType();
+ if (TLI.isOperationLegalOrCustom(ISD::ABS, VT))
+ return DAG.getNode(ISD::ABS, DL, VT, LHS);
+
SDValue Shift = DAG.getNode(
ISD::SRA, DL, VT, LHS,
- DAG.getConstant(VT.getScalarType().getSizeInBits() - 1, DL, VT));
+ DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT));
SDValue Add = DAG.getNode(ISD::ADD, DL, VT, LHS, Shift);
AddToWorklist(Shift.getNode());
AddToWorklist(Add.getNode());
return DAG.getNode(ISD::XOR, DL, VT, Add, Shift);
}
- }
-
- if (SimplifySelectOps(N, N1, N2))
- return SDValue(N, 0); // Don't revisit N.
- // If the VSELECT result requires splitting and the mask is provided by a
- // SETCC, then split both nodes and its operands before legalization. This
- // prevents the type legalizer from unrolling SETCC into scalar comparisons
- // and enables future optimizations (e.g. min/max pattern matching on X86).
- if (N0.getOpcode() == ISD::SETCC) {
+ // vselect x, y (fcmp lt x, y) -> fminnum x, y
+ // vselect x, y (fcmp gt x, y) -> fmaxnum x, y
+ //
+ // This is OK if we don't care about what happens if either operand is a
+ // NaN.
+ //
EVT VT = N->getValueType(0);
+ if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, N0.getOperand(0), N0.getOperand(1))) {
+ ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
+ if (SDValue FMinMax = combineMinNumMaxNum(
+ DL, VT, N0.getOperand(0), N0.getOperand(1), N1, N2, CC, TLI, DAG))
+ return FMinMax;
+ }
- // Check if any splitting is required.
- if (TLI.getTypeAction(*DAG.getContext(), VT) !=
- TargetLowering::TypeSplitVector)
- return SDValue();
-
- SDValue Lo, Hi, CCLo, CCHi, LL, LH, RL, RH;
- std::tie(CCLo, CCHi) = SplitVSETCC(N0.getNode(), DAG);
- std::tie(LL, LH) = DAG.SplitVectorOperand(N, 1);
- std::tie(RL, RH) = DAG.SplitVectorOperand(N, 2);
-
- Lo = DAG.getNode(N->getOpcode(), DL, LL.getValueType(), CCLo, LL, RL);
- Hi = DAG.getNode(N->getOpcode(), DL, LH.getValueType(), CCHi, LH, RH);
-
- // Add the new VSELECT nodes to the work list in case they need to be split
- // again.
- AddToWorklist(Lo.getNode());
- AddToWorklist(Hi.getNode());
-
- return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi);
+ // If this select has a condition (setcc) with narrower operands than the
+ // select, try to widen the compare to match the select width.
+ // TODO: This should be extended to handle any constant.
+ // TODO: This could be extended to handle non-loading patterns, but that
+ // requires thorough testing to avoid regressions.
+ if (isNullOrNullSplat(RHS)) {
+ EVT NarrowVT = LHS.getValueType();
+ EVT WideVT = N1.getValueType().changeVectorElementTypeToInteger();
+ EVT SetCCVT = getSetCCResultType(LHS.getValueType());
+ unsigned SetCCWidth = SetCCVT.getScalarSizeInBits();
+ unsigned WideWidth = WideVT.getScalarSizeInBits();
+ bool IsSigned = isSignedIntSetCC(CC);
+ auto LoadExtOpcode = IsSigned ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
+ if (LHS.getOpcode() == ISD::LOAD && LHS.hasOneUse() &&
+ SetCCWidth != 1 && SetCCWidth < WideWidth &&
+ TLI.isLoadExtLegalOrCustom(LoadExtOpcode, WideVT, NarrowVT) &&
+ TLI.isOperationLegalOrCustom(ISD::SETCC, WideVT)) {
+ // Both compare operands can be widened for free. The LHS can use an
+ // extended load, and the RHS is a constant:
+ // vselect (ext (setcc load(X), C)), N1, N2 -->
+ // vselect (setcc extload(X), C'), N1, N2
+ auto ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+ SDValue WideLHS = DAG.getNode(ExtOpcode, DL, WideVT, LHS);
+ SDValue WideRHS = DAG.getNode(ExtOpcode, DL, WideVT, RHS);
+ EVT WideSetCCVT = getSetCCResultType(WideVT);
+ SDValue WideSetCC = DAG.getSetCC(DL, WideSetCCVT, WideLHS, WideRHS, CC);
+ return DAG.getSelect(DL, N1.getValueType(), WideSetCC, N1, N2);
+ }
+ }
}
+ if (SimplifySelectOps(N, N1, N2))
+ return SDValue(N, 0); // Don't revisit N.
+
// Fold (vselect (build_vector all_ones), N1, N2) -> N1
if (ISD::isBuildVectorAllOnes(N0.getNode()))
return N1;
@@ -5650,6 +8046,9 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
return CV;
}
+ if (SDValue V = foldVSelectOfConstants(N))
+ return V;
+
return SDValue();
}
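
(Editorial sketch, not part of the patch: the abs canonicalization near the top of visitVSELECT falls back to the classic shift/add/xor expansion when ISD::ABS is not legal: with s = x >> (N-1) using an arithmetic shift, abs(x) = (x + s) ^ s. Standalone check; assumes 32-bit int with arithmetic right shift, and INT_MIN is excluded because its negation overflows:)

#include <cassert>

int main() {
  for (int x : {-7, 0, 42}) {
    int s = x >> 31;                       // smear the sign bit
    assert(((x + s) ^ s) == (x < 0 ? -x : x));
  }
  return 0;
}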
@@ -5666,9 +8065,8 @@ SDValue DAGCombiner::visitSELECT_CC(SDNode *N) {
return N2;
// Determine if the condition we're dealing with is constant
- SDValue SCC = SimplifySetCC(getSetCCResultType(N0.getValueType()),
- N0, N1, CC, SDLoc(N), false);
- if (SCC.getNode()) {
+ if (SDValue SCC = SimplifySetCC(getSetCCResultType(N0.getValueType()), N0, N1,
+ CC, SDLoc(N), false)) {
AddToWorklist(SCC.getNode());
if (ConstantSDNode *SCCC = dyn_cast<ConstantSDNode>(SCC.getNode())) {
@@ -5676,7 +8074,7 @@ SDValue DAGCombiner::visitSELECT_CC(SDNode *N) {
return N2; // cond always true -> true val
else
return N3; // cond always false -> false val
- } else if (SCC->getOpcode() == ISD::UNDEF) {
+ } else if (SCC->isUndef()) {
// When the condition is UNDEF, just return the first operand. This is
// coherent with the DAG creation; no setcc node is created in this case.
return N2;
@@ -5697,19 +8095,43 @@ SDValue DAGCombiner::visitSELECT_CC(SDNode *N) {
}
SDValue DAGCombiner::visitSETCC(SDNode *N) {
- return SimplifySetCC(N->getValueType(0), N->getOperand(0), N->getOperand(1),
- cast<CondCodeSDNode>(N->getOperand(2))->get(),
- SDLoc(N));
+ // setcc is very commonly used as an argument to brcond. This pattern
+ // also lends itself to numerous combines and, as a result, it is desirable
+ // to keep the argument to a brcond as a setcc as much as possible.
+ bool PreferSetCC =
+ N->hasOneUse() && N->use_begin()->getOpcode() == ISD::BRCOND;
+
+ SDValue Combined = SimplifySetCC(
+ N->getValueType(0), N->getOperand(0), N->getOperand(1),
+ cast<CondCodeSDNode>(N->getOperand(2))->get(), SDLoc(N), !PreferSetCC);
+
+ if (!Combined)
+ return SDValue();
+
+ // If we prefer to have a setcc, and we don't, we'll try our best to
+ // recreate one using rebuildSetCC.
+ if (PreferSetCC && Combined.getOpcode() != ISD::SETCC) {
+ SDValue NewSetCC = rebuildSetCC(Combined);
+
+ // We don't have anything interesting to combine to.
+ if (NewSetCC.getNode() == N)
+ return SDValue();
+
+ if (NewSetCC)
+ return NewSetCC;
+ }
+
+ return Combined;
}
-SDValue DAGCombiner::visitSETCCE(SDNode *N) {
+SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) {
SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);
SDValue Carry = N->getOperand(2);
SDValue Cond = N->getOperand(3);
// If Carry is false, fold to a regular SETCC.
- if (Carry.getOpcode() == ISD::CARRY_FALSE)
+ if (isNullConstant(Carry))
return DAG.getNode(ISD::SETCC, SDLoc(N), N->getVTList(), LHS, RHS, Cond);
return SDValue();
@@ -5721,43 +8143,47 @@ SDValue DAGCombiner::visitSETCCE(SDNode *N) {
/// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
/// Vector extends are not folded if operations are legal; this is to
/// avoid introducing illegal build_vector dag nodes.
-static SDNode *tryToFoldExtendOfConstant(SDNode *N, const TargetLowering &TLI,
- SelectionDAG &DAG, bool LegalTypes,
- bool LegalOperations) {
+static SDValue tryToFoldExtendOfConstant(SDNode *N, const TargetLowering &TLI,
+ SelectionDAG &DAG, bool LegalTypes) {
unsigned Opcode = N->getOpcode();
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);
assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
- Opcode == ISD::ANY_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG)
+ Opcode == ISD::ANY_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG ||
+ Opcode == ISD::ZERO_EXTEND_VECTOR_INREG)
&& "Expected EXTEND dag node in input!");
// fold (sext c1) -> c1
// fold (zext c1) -> c1
// fold (aext c1) -> c1
if (isa<ConstantSDNode>(N0))
- return DAG.getNode(Opcode, SDLoc(N), VT, N0).getNode();
+ return DAG.getNode(Opcode, SDLoc(N), VT, N0);
// fold (sext (build_vector AllConstants) -> (build_vector AllConstants)
// fold (zext (build_vector AllConstants) -> (build_vector AllConstants)
// fold (aext (build_vector AllConstants) -> (build_vector AllConstants)
EVT SVT = VT.getScalarType();
- if (!(VT.isVector() &&
- (!LegalTypes || (!LegalOperations && TLI.isTypeLegal(SVT))) &&
+ if (!(VT.isVector() && (!LegalTypes || TLI.isTypeLegal(SVT)) &&
ISD::isBuildVectorOfConstantSDNodes(N0.getNode())))
- return nullptr;
+ return SDValue();
// We can fold this node into a build_vector.
unsigned VTBits = SVT.getSizeInBits();
- unsigned EVTBits = N0->getValueType(0).getScalarType().getSizeInBits();
+ unsigned EVTBits = N0->getValueType(0).getScalarSizeInBits();
SmallVector<SDValue, 8> Elts;
unsigned NumElts = VT.getVectorNumElements();
SDLoc DL(N);
- for (unsigned i=0; i != NumElts; ++i) {
- SDValue Op = N0->getOperand(i);
- if (Op->getOpcode() == ISD::UNDEF) {
- Elts.push_back(DAG.getUNDEF(SVT));
+ // For zero-extensions, UNDEF elements are still guaranteed to have their
+ // upper bits set to zero.
+ bool IsZext =
+ Opcode == ISD::ZERO_EXTEND || Opcode == ISD::ZERO_EXTEND_VECTOR_INREG;
+
+ for (unsigned i = 0; i != NumElts; ++i) {
+ SDValue Op = N0.getOperand(i);
+ if (Op.isUndef()) {
+ Elts.push_back(IsZext ? DAG.getConstant(0, DL, SVT) : DAG.getUNDEF(SVT));
continue;
}
@@ -5771,19 +8197,19 @@ static SDNode *tryToFoldExtendOfConstant(SDNode *N, const TargetLowering &TLI,
Elts.push_back(DAG.getConstant(C.zext(VTBits), DL, SVT));
}
- return DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Elts).getNode();
+ return DAG.getBuildVector(VT, DL, Elts);
}
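
(Editorial sketch, not part of the patch: tryToFoldExtendOfConstant folds an extend of a constant, or of a constant build_vector element-wise, into the extended constant, now also materializing zero for UNDEF lanes under zext. The scalar analogue for i8 -> i32, illustrative only:)

#include <cassert>
#include <cstdint>

int main() {
  int8_t C = -5;                          // bit pattern 0xFB
  assert((int32_t)C == -5);               // fold (sext c1) -> c1
  assert((uint32_t)(uint8_t)C == 251u);   // fold (zext c1) -> c1, high bits zero
  return 0;
}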
// ExtendUsesToFormExtLoad - Trying to extend uses of a load to enable this:
// "fold ({s|z|a}ext (load x)) -> ({s|z|a}ext (truncate ({s|z|a}extload x)))"
// transformation. Returns true if extension are possible and the above
// mentioned transformation is profitable.
-static bool ExtendUsesToFormExtLoad(SDNode *N, SDValue N0,
+static bool ExtendUsesToFormExtLoad(EVT VT, SDNode *N, SDValue N0,
unsigned ExtOpc,
SmallVectorImpl<SDNode *> &ExtendNodes,
const TargetLowering &TLI) {
bool HasCopyToRegUses = false;
- bool isTruncFree = TLI.isTruncateFree(N->getValueType(0), N0.getValueType());
+ bool isTruncFree = TLI.isTruncateFree(VT, N0.getValueType());
for (SDNode::use_iterator UI = N0.getNode()->use_begin(),
UE = N0.getNode()->use_end();
UI != UE; ++UI) {
@@ -5839,16 +8265,16 @@ static bool ExtendUsesToFormExtLoad(SDNode *N, SDValue N0,
}
void DAGCombiner::ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
- SDValue Trunc, SDValue ExtLoad, SDLoc DL,
+ SDValue OrigLoad, SDValue ExtLoad,
ISD::NodeType ExtType) {
// Extend SetCC uses if necessary.
- for (unsigned i = 0, e = SetCCs.size(); i != e; ++i) {
- SDNode *SetCC = SetCCs[i];
+ SDLoc DL(ExtLoad);
+ for (SDNode *SetCC : SetCCs) {
SmallVector<SDValue, 4> Ops;
for (unsigned j = 0; j != 2; ++j) {
SDValue SOp = SetCC->getOperand(j);
- if (SOp == Trunc)
+ if (SOp == OrigLoad)
Ops.push_back(ExtLoad);
else
Ops.push_back(DAG.getNode(ExtType, DL, ExtLoad->getValueType(0), SOp));
@@ -5897,7 +8323,7 @@ SDValue DAGCombiner::CombineExtLoad(SDNode *N) {
return SDValue();
SmallVector<SDNode *, 4> SetCCs;
- if (!ExtendUsesToFormExtLoad(N, N0, N->getOpcode(), SetCCs, TLI))
+ if (!ExtendUsesToFormExtLoad(DstVT, N, N0, N->getOpcode(), SetCCs, TLI))
return SDValue();
ISD::LoadExtType ExtType =
@@ -5928,10 +8354,9 @@ SDValue DAGCombiner::CombineExtLoad(SDNode *N) {
const unsigned Align = MinAlign(LN0->getAlignment(), Offset);
SDValue SplitLoad = DAG.getExtLoad(
- ExtType, DL, SplitDstVT, LN0->getChain(), BasePtr,
- LN0->getPointerInfo().getWithOffset(Offset), SplitSrcVT,
- LN0->isVolatile(), LN0->isNonTemporal(), LN0->isInvariant(),
- Align, LN0->getAAInfo());
+ ExtType, SDLoc(LN0), SplitDstVT, LN0->getChain(), BasePtr,
+ LN0->getPointerInfo().getWithOffset(Offset), SplitSrcVT, Align,
+ LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
BasePtr = DAG.getNode(ISD::ADD, DL, BasePtr.getValueType(), BasePtr,
DAG.getConstant(Stride, DL, BasePtr.getValueType()));
@@ -5943,37 +8368,254 @@ SDValue DAGCombiner::CombineExtLoad(SDNode *N) {
SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains);
SDValue NewValue = DAG.getNode(ISD::CONCAT_VECTORS, DL, DstVT, Loads);
+ // Simplify TF.
+ AddToWorklist(NewChain.getNode());
+
CombineTo(N, NewValue);
// Replace uses of the original load (before extension)
// with a truncate of the concatenated sextloaded vectors.
SDValue Trunc =
DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), NewValue);
+ ExtendSetCCUses(SetCCs, N0, NewValue, (ISD::NodeType)N->getOpcode());
CombineTo(N0.getNode(), Trunc, NewChain);
- ExtendSetCCUses(SetCCs, Trunc, NewValue, DL,
- (ISD::NodeType)N->getOpcode());
return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
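
(Editorial sketch, not part of the patch: CombineExtLoad splits one wide extload into several narrower extloads at increasing offsets and concatenates the results. Each lane is extended independently, so the halves agree with the unsplit load. Plain-C++ sketch of the lane-wise equivalence with hypothetical data, not the DAG API:)

#include <cassert>
#include <cstdint>

int main() {
  int8_t Src[4] = {-1, 2, -3, 4};
  int32_t Whole[4], Split[4];
  for (int i = 0; i < 4; ++i) Whole[i] = Src[i];  // one 4-lane sextload
  for (int i = 0; i < 2; ++i) Split[i] = Src[i];  // lo half, offset 0
  for (int i = 2; i < 4; ++i) Split[i] = Src[i];  // hi half, offset Stride
  for (int i = 0; i < 4; ++i) assert(Whole[i] == Split[i]);
  return 0;
}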
+// fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
+// (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
+SDValue DAGCombiner::CombineZExtLogicopShiftLoad(SDNode *N) {
+ assert(N->getOpcode() == ISD::ZERO_EXTEND);
+ EVT VT = N->getValueType(0);
+
+ // and/or/xor
+ SDValue N0 = N->getOperand(0);
+ if (!(N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR ||
+ N0.getOpcode() == ISD::XOR) ||
+ N0.getOperand(1).getOpcode() != ISD::Constant ||
+ (LegalOperations && !TLI.isOperationLegal(N0.getOpcode(), VT)))
+ return SDValue();
+
+ // shl/shr
+ SDValue N1 = N0->getOperand(0);
+ if (!(N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) ||
+ N1.getOperand(1).getOpcode() != ISD::Constant ||
+ (LegalOperations && !TLI.isOperationLegal(N1.getOpcode(), VT)))
+ return SDValue();
+
+ // load
+ if (!isa<LoadSDNode>(N1.getOperand(0)))
+ return SDValue();
+ LoadSDNode *Load = cast<LoadSDNode>(N1.getOperand(0));
+ EVT MemVT = Load->getMemoryVT();
+ if (!TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) ||
+ Load->getExtensionType() == ISD::SEXTLOAD || Load->isIndexed())
+ return SDValue();
+
+ // If the shift op is SHL, the logic op must be AND, otherwise the result
+ // will be wrong.
+ if (N1.getOpcode() == ISD::SHL && N0.getOpcode() != ISD::AND)
+ return SDValue();
+
+ if (!N0.hasOneUse() || !N1.hasOneUse())
+ return SDValue();
+
+ SmallVector<SDNode*, 4> SetCCs;
+ if (!ExtendUsesToFormExtLoad(VT, N1.getNode(), N1.getOperand(0),
+ ISD::ZERO_EXTEND, SetCCs, TLI))
+ return SDValue();
+
+ // Actually do the transformation.
+ SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(Load), VT,
+ Load->getChain(), Load->getBasePtr(),
+ Load->getMemoryVT(), Load->getMemOperand());
+
+ SDLoc DL1(N1);
+ SDValue Shift = DAG.getNode(N1.getOpcode(), DL1, VT, ExtLoad,
+ N1.getOperand(1));
+
+ APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue();
+ Mask = Mask.zext(VT.getSizeInBits());
+ SDLoc DL0(N0);
+ SDValue And = DAG.getNode(N0.getOpcode(), DL0, VT, Shift,
+ DAG.getConstant(Mask, DL0, VT));
+
+ ExtendSetCCUses(SetCCs, N1.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
+ CombineTo(N, And);
+ if (SDValue(Load, 0).hasOneUse()) {
+ DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), ExtLoad.getValue(1));
+ } else {
+ SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(Load),
+ Load->getValueType(0), ExtLoad);
+ CombineTo(Load, Trunc, ExtLoad.getValue(1));
+ }
+ return SDValue(N,0); // Return N so it doesn't get rechecked!
+}
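
(Editorial sketch, not part of the patch: the fold in CombineZExtLogicopShiftLoad is a width-change identity. The SHL/AND restriction in the code matters because the zero-extended mask then clears every bit shifted past the narrow width. Scalar check with illustrative constants:)

#include <cassert>
#include <cstdint>

int main() {
  uint8_t X = 0xAB, M = 0x70;
  unsigned S = 3;
  uint32_t Narrow = (uint8_t)((X << S) & M);           // shl/and in i8, then zext
  uint32_t Wide = (((uint32_t)X << S) & (uint32_t)M);  // zextload, shl/and in i32
  assert(Narrow == Wide);
  return 0;
}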
+
+/// If we're narrowing or widening the result of a vector select and the final
+/// size is the same size as a setcc (compare) feeding the select, then try to
+/// apply the cast operation to the select's operands because matching vector
+/// sizes for a select condition and other operands should be more efficient.
+SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
+ unsigned CastOpcode = Cast->getOpcode();
+ assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
+ CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
+ CastOpcode == ISD::FP_ROUND) &&
+ "Unexpected opcode for vector select narrowing/widening");
+
+ // We only do this transform before legal ops because the pattern may be
+ // obfuscated by target-specific operations after legalization. Do not create
+ // an illegal select op, however, because that may be difficult to lower.
+ EVT VT = Cast->getValueType(0);
+ if (LegalOperations || !TLI.isOperationLegalOrCustom(ISD::VSELECT, VT))
+ return SDValue();
+
+ SDValue VSel = Cast->getOperand(0);
+ if (VSel.getOpcode() != ISD::VSELECT || !VSel.hasOneUse() ||
+ VSel.getOperand(0).getOpcode() != ISD::SETCC)
+ return SDValue();
+
+ // Does the setcc have the same vector size as the casted select?
+ SDValue SetCC = VSel.getOperand(0);
+ EVT SetCCVT = getSetCCResultType(SetCC.getOperand(0).getValueType());
+ if (SetCCVT.getSizeInBits() != VT.getSizeInBits())
+ return SDValue();
+
+ // cast (vsel (setcc X), A, B) --> vsel (setcc X), (cast A), (cast B)
+ SDValue A = VSel.getOperand(1);
+ SDValue B = VSel.getOperand(2);
+ SDValue CastA, CastB;
+ SDLoc DL(Cast);
+ if (CastOpcode == ISD::FP_ROUND) {
+ // FP_ROUND (fptrunc) has an extra flag operand to pass along.
+ CastA = DAG.getNode(CastOpcode, DL, VT, A, Cast->getOperand(1));
+ CastB = DAG.getNode(CastOpcode, DL, VT, B, Cast->getOperand(1));
+ } else {
+ CastA = DAG.getNode(CastOpcode, DL, VT, A);
+ CastB = DAG.getNode(CastOpcode, DL, VT, B);
+ }
+ return DAG.getNode(ISD::VSELECT, DL, VT, SetCC, CastA, CastB);
+}
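
(Editorial sketch, not part of the patch: matchVSelectOpSizesWithSetCC uses the fact that a cast distributes over a select lane-wise, i.e. cast(c ? a : b) == c ? cast(a) : cast(b). Scalar check with illustrative values:)

#include <cassert>
#include <cstdint>

int main() {
  int32_t A = 300, B = -2;
  for (bool C : {false, true})
    assert((uint8_t)(C ? A : B) == (C ? (uint8_t)A : (uint8_t)B));
  return 0;
}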
+
+// fold ([s|z]ext ([s|z]extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
+// fold ([s|z]ext ( extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
+static SDValue tryToFoldExtOfExtload(SelectionDAG &DAG, DAGCombiner &Combiner,
+ const TargetLowering &TLI, EVT VT,
+ bool LegalOperations, SDNode *N,
+ SDValue N0, ISD::LoadExtType ExtLoadType) {
+ SDNode *N0Node = N0.getNode();
+ bool isAExtLoad = (ExtLoadType == ISD::SEXTLOAD) ? ISD::isSEXTLoad(N0Node)
+ : ISD::isZEXTLoad(N0Node);
+ if ((!isAExtLoad && !ISD::isEXTLoad(N0Node)) ||
+ !ISD::isUNINDEXEDLoad(N0Node) || !N0.hasOneUse())
+ return {};
+
+ LoadSDNode *LN0 = cast<LoadSDNode>(N0);
+ EVT MemVT = LN0->getMemoryVT();
+ if ((LegalOperations || LN0->isVolatile() || VT.isVector()) &&
+ !TLI.isLoadExtLegal(ExtLoadType, VT, MemVT))
+ return {};
+
+ SDValue ExtLoad =
+ DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
+ LN0->getBasePtr(), MemVT, LN0->getMemOperand());
+ Combiner.CombineTo(N, ExtLoad);
+ DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
+ return SDValue(N, 0); // Return N so it doesn't get rechecked!
+}
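
(Editorial sketch, not part of the patch: tryToFoldExtOfExtload collapses a second extension of an already-extended load into one wider extload. Scalar analogue for sext, illustrative only:)

#include <cassert>
#include <cstdint>

int main() {
  int8_t X = -42;
  int16_t Loaded = X;                   // sextload i8 -> i16
  assert((int32_t)Loaded == (int32_t)X); // sext i16 -> i32 == sextload i8 -> i32
  return 0;
}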
+
+// fold ([s|z]ext (load x)) -> ([s|z]ext (truncate ([s|z]extload x)))
+// Only generate vector extloads when 1) they're legal, and 2) they are
+// deemed desirable by the target.
+static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner,
+ const TargetLowering &TLI, EVT VT,
+ bool LegalOperations, SDNode *N, SDValue N0,
+ ISD::LoadExtType ExtLoadType,
+ ISD::NodeType ExtOpc) {
+ if (!ISD::isNON_EXTLoad(N0.getNode()) ||
+ !ISD::isUNINDEXEDLoad(N0.getNode()) ||
+ ((LegalOperations || VT.isVector() ||
+ cast<LoadSDNode>(N0)->isVolatile()) &&
+ !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType())))
+ return {};
+
+ bool DoXform = true;
+ SmallVector<SDNode *, 4> SetCCs;
+ if (!N0.hasOneUse())
+ DoXform = ExtendUsesToFormExtLoad(VT, N, N0, ExtOpc, SetCCs, TLI);
+ if (VT.isVector())
+ DoXform &= TLI.isVectorLoadExtDesirable(SDValue(N, 0));
+ if (!DoXform)
+ return {};
+
+ LoadSDNode *LN0 = cast<LoadSDNode>(N0);
+ SDValue ExtLoad = DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
+ LN0->getBasePtr(), N0.getValueType(),
+ LN0->getMemOperand());
+ Combiner.ExtendSetCCUses(SetCCs, N0, ExtLoad, ExtOpc);
+ // If the load value is used only by N, replace it via CombineTo N.
+ bool NoReplaceTrunc = SDValue(LN0, 0).hasOneUse();
+ Combiner.CombineTo(N, ExtLoad);
+ if (NoReplaceTrunc) {
+ DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
+ } else {
+ SDValue Trunc =
+ DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), ExtLoad);
+ Combiner.CombineTo(LN0, Trunc, ExtLoad.getValue(1));
+ }
+ return SDValue(N, 0); // Return N so it doesn't get rechecked!
+}
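
(Editorial sketch, not part of the patch: tryToFoldExtOfLoad replaces (sext/zext (load x)) with a single extending load; the value is identical whether the extension happens as part of the load or afterwards. Minimal check with a hypothetical buffer, not the DAG API:)

#include <cassert>
#include <cstdint>

int main() {
  int8_t Buf[1] = {-42};
  int8_t Narrow = Buf[0];                  // load i8
  int32_t TwoStep = (int32_t)Narrow;       // then sext
  int32_t OneStep = *(const int8_t *)Buf;  // sextload i8 -> i32
  assert(TwoStep == OneStep && OneStep == -42);
  return 0;
}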
+
+static SDValue foldExtendedSignBitTest(SDNode *N, SelectionDAG &DAG,
+ bool LegalOperations) {
+ assert((N->getOpcode() == ISD::SIGN_EXTEND ||
+ N->getOpcode() == ISD::ZERO_EXTEND) && "Expected sext or zext");
+
+ SDValue SetCC = N->getOperand(0);
+ if (LegalOperations || SetCC.getOpcode() != ISD::SETCC ||
+ !SetCC.hasOneUse() || SetCC.getValueType() != MVT::i1)
+ return SDValue();
+
+ SDValue X = SetCC.getOperand(0);
+ SDValue Ones = SetCC.getOperand(1);
+ ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get();
+ EVT VT = N->getValueType(0);
+ EVT XVT = X.getValueType();
+ // setge X, C is canonicalized to setgt, so we do not need to match that
+ // pattern. The setlt sibling is folded in SimplifySelectCC() because it does
+ // not require the 'not' op.
+ if (CC == ISD::SETGT && isAllOnesConstant(Ones) && VT == XVT) {
+ // Invert and smear/shift the sign bit:
+ // sext i1 (setgt iN X, -1) --> sra (not X), (N - 1)
+ // zext i1 (setgt iN X, -1) --> srl (not X), (N - 1)
+ SDLoc DL(N);
+ SDValue NotX = DAG.getNOT(DL, X, VT);
+ SDValue ShiftAmount = DAG.getConstant(VT.getSizeInBits() - 1, DL, VT);
+ auto ShiftOpcode = N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SRA : ISD::SRL;
+ return DAG.getNode(ShiftOpcode, DL, VT, NotX, ShiftAmount);
+ }
+ return SDValue();
+}
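
(Editorial sketch, not part of the patch: a standalone check of both rewrites in foldExtendedSignBitTest. Assumes arithmetic >> on int32_t, which mainstream compilers provide and C++20 guarantees:)

#include <cassert>
#include <cstdint>

int main() {
  for (int32_t X : {-7, 0, 42, INT32_MIN, INT32_MAX}) {
    int32_t Sext = (X > -1) ? -1 : 0;      // sext i1 (setgt X, -1)
    assert(Sext == (int32_t)(~X) >> 31);   // sra (not X), 31
    uint32_t Zext = (X > -1) ? 1u : 0u;    // zext i1 (setgt X, -1)
    assert(Zext == (uint32_t)(~X) >> 31);  // srl (not X), 31
  }
  return 0;
}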
+
SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);
+ SDLoc DL(N);
- if (SDNode *Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes,
- LegalOperations))
- return SDValue(Res, 0);
+ if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
+ return Res;
// fold (sext (sext x)) -> (sext x)
// fold (sext (aext x)) -> (sext x)
if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
- return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT,
- N0.getOperand(0));
+ return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N0.getOperand(0));
if (N0.getOpcode() == ISD::TRUNCATE) {
// fold (sext (truncate (load x))) -> (sext (smaller load x))
// fold (sext (truncate (srl (load x), c))) -> (sext (smaller load (x+c/n)))
if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
- SDNode* oye = N0.getNode()->getOperand(0).getNode();
+ SDNode *oye = N0.getOperand(0).getNode();
if (NarrowLoad.getNode() != N0.getNode()) {
CombineTo(N0.getNode(), NarrowLoad);
// CombineTo deleted the truncate, if needed, but not what's under it.
@@ -5985,9 +8627,9 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
// See if the value being truncated is already sign extended. If so, just
// eliminate the trunc/sext pair.
SDValue Op = N0.getOperand(0);
- unsigned OpBits = Op.getValueType().getScalarType().getSizeInBits();
- unsigned MidBits = N0.getValueType().getScalarType().getSizeInBits();
- unsigned DestBits = VT.getScalarType().getSizeInBits();
+ unsigned OpBits = Op.getScalarValueSizeInBits();
+ unsigned MidBits = N0.getScalarValueSizeInBits();
+ unsigned DestBits = VT.getScalarSizeInBits();
unsigned NumSignBits = DAG.ComputeNumSignBits(Op);
if (OpBits == DestBits) {
@@ -5999,12 +8641,12 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
// Op is i32, Mid is i8, and Dest is i64. If Op has more than 24 sign
// bits, just sext from i32.
if (NumSignBits > OpBits-MidBits)
- return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, Op);
+ return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op);
} else {
// Op is i64, Mid is i8, and Dest is i32. If Op has more than 56 sign
// bits, just truncate to i32.
if (NumSignBits > OpBits-MidBits)
- return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Op);
+ return DAG.getNode(ISD::TRUNCATE, DL, VT, Op);
}
// fold (sext (truncate x)) -> (sextinreg x).
@@ -6014,65 +8656,26 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
Op = DAG.getNode(ISD::ANY_EXTEND, SDLoc(N0), VT, Op);
else if (OpBits > DestBits)
Op = DAG.getNode(ISD::TRUNCATE, SDLoc(N0), VT, Op);
- return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, Op,
+ return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, Op,
DAG.getValueType(N0.getValueType()));
}
}
- // fold (sext (load x)) -> (sext (truncate (sextload x)))
- // Only generate vector extloads when 1) they're legal, and 2) they are
- // deemed desirable by the target.
- if (ISD::isNON_EXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
- ((!LegalOperations && !VT.isVector() &&
- !cast<LoadSDNode>(N0)->isVolatile()) ||
- TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, N0.getValueType()))) {
- bool DoXform = true;
- SmallVector<SDNode*, 4> SetCCs;
- if (!N0.hasOneUse())
- DoXform = ExtendUsesToFormExtLoad(N, N0, ISD::SIGN_EXTEND, SetCCs, TLI);
- if (VT.isVector())
- DoXform &= TLI.isVectorLoadExtDesirable(SDValue(N, 0));
- if (DoXform) {
- LoadSDNode *LN0 = cast<LoadSDNode>(N0);
- SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT,
- LN0->getChain(),
- LN0->getBasePtr(), N0.getValueType(),
- LN0->getMemOperand());
- CombineTo(N, ExtLoad);
- SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(N0),
- N0.getValueType(), ExtLoad);
- CombineTo(N0.getNode(), Trunc, ExtLoad.getValue(1));
- ExtendSetCCUses(SetCCs, Trunc, ExtLoad, SDLoc(N),
- ISD::SIGN_EXTEND);
- return SDValue(N, 0); // Return N so it doesn't get rechecked!
- }
- }
+ // Try to simplify (sext (load x)).
+ if (SDValue foldedExt =
+ tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
+ ISD::SEXTLOAD, ISD::SIGN_EXTEND))
+ return foldedExt;
// fold (sext (load x)) to multiple smaller sextloads.
// Only on illegal but splittable vectors.
if (SDValue ExtLoad = CombineExtLoad(N))
return ExtLoad;
- // fold (sext (sextload x)) -> (sext (truncate (sextload x)))
- // fold (sext ( extload x)) -> (sext (truncate (sextload x)))
- if ((ISD::isSEXTLoad(N0.getNode()) || ISD::isEXTLoad(N0.getNode())) &&
- ISD::isUNINDEXEDLoad(N0.getNode()) && N0.hasOneUse()) {
- LoadSDNode *LN0 = cast<LoadSDNode>(N0);
- EVT MemVT = LN0->getMemoryVT();
- if ((!LegalOperations && !LN0->isVolatile()) ||
- TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT)) {
- SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT,
- LN0->getChain(),
- LN0->getBasePtr(), MemVT,
- LN0->getMemOperand());
- CombineTo(N, ExtLoad);
- CombineTo(N0.getNode(),
- DAG.getNode(ISD::TRUNCATE, SDLoc(N0),
- N0.getValueType(), ExtLoad),
- ExtLoad.getValue(1));
- return SDValue(N, 0); // Return N so it doesn't get rechecked!
- }
- }
+ // Try to simplify (sext (sextload x)).
+ if (SDValue foldedExt = tryToFoldExtOfExtload(
+ DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::SEXTLOAD))
+ return foldedExt;
// fold (sext (and/or/xor (load x), cst)) ->
// (and/or/xor (sextload x), (sext cst))
@@ -6080,92 +8683,115 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
N0.getOpcode() == ISD::XOR) &&
isa<LoadSDNode>(N0.getOperand(0)) &&
N0.getOperand(1).getOpcode() == ISD::Constant &&
- TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, N0.getValueType()) &&
(!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
- LoadSDNode *LN0 = cast<LoadSDNode>(N0.getOperand(0));
- if (LN0->getExtensionType() != ISD::ZEXTLOAD && LN0->isUnindexed()) {
- bool DoXform = true;
+ LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
+ EVT MemVT = LN00->getMemoryVT();
+ if (TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT) &&
+ LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
SmallVector<SDNode*, 4> SetCCs;
- if (!N0.hasOneUse())
- DoXform = ExtendUsesToFormExtLoad(N, N0.getOperand(0), ISD::SIGN_EXTEND,
- SetCCs, TLI);
+ bool DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
+ ISD::SIGN_EXTEND, SetCCs, TLI);
if (DoXform) {
- SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(LN0), VT,
- LN0->getChain(), LN0->getBasePtr(),
- LN0->getMemoryVT(),
- LN0->getMemOperand());
+ SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(LN00), VT,
+ LN00->getChain(), LN00->getBasePtr(),
+ LN00->getMemoryVT(),
+ LN00->getMemOperand());
APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue();
Mask = Mask.sext(VT.getSizeInBits());
- SDLoc DL(N);
SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
ExtLoad, DAG.getConstant(Mask, DL, VT));
- SDValue Trunc = DAG.getNode(ISD::TRUNCATE,
- SDLoc(N0.getOperand(0)),
- N0.getOperand(0).getValueType(), ExtLoad);
+ ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::SIGN_EXTEND);
+ bool NoReplaceTruncAnd = !N0.hasOneUse();
+ bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
CombineTo(N, And);
- CombineTo(N0.getOperand(0).getNode(), Trunc, ExtLoad.getValue(1));
- ExtendSetCCUses(SetCCs, Trunc, ExtLoad, DL,
- ISD::SIGN_EXTEND);
- return SDValue(N, 0); // Return N so it doesn't get rechecked!
+ // If N0 has multiple uses, change other uses as well.
+ if (NoReplaceTruncAnd) {
+ SDValue TruncAnd =
+ DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And);
+ CombineTo(N0.getNode(), TruncAnd);
+ }
+ if (NoReplaceTrunc) {
+ DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
+ } else {
+ SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
+ LN00->getValueType(0), ExtLoad);
+ CombineTo(LN00, Trunc, ExtLoad.getValue(1));
+ }
+ return SDValue(N,0); // Return N so it doesn't get rechecked!
}
}
}
+ if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
+ return V;
+
if (N0.getOpcode() == ISD::SETCC) {
- EVT N0VT = N0.getOperand(0).getValueType();
+ SDValue N00 = N0.getOperand(0);
+ SDValue N01 = N0.getOperand(1);
+ ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
+ EVT N00VT = N0.getOperand(0).getValueType();
+
// sext(setcc) -> sext_in_reg(vsetcc) for vectors.
// Only do this before legalize for now.
if (VT.isVector() && !LegalOperations &&
- TLI.getBooleanContents(N0VT) ==
+ TLI.getBooleanContents(N00VT) ==
TargetLowering::ZeroOrNegativeOneBooleanContent) {
// On some architectures (such as SSE/NEON/etc) the SETCC result type is
// of the same size as the compared operands. Only optimize sext(setcc())
// if this is the case.
- EVT SVT = getSetCCResultType(N0VT);
+ EVT SVT = getSetCCResultType(N00VT);
- // We know that the # elements of the results is the same as the
- // # elements of the compare (and the # elements of the compare result
- // for that matter). Check to see that they are the same size. If so,
- // we know that the element size of the sext'd result matches the
- // element size of the compare operands.
- if (VT.getSizeInBits() == SVT.getSizeInBits())
- return DAG.getSetCC(SDLoc(N), VT, N0.getOperand(0),
- N0.getOperand(1),
- cast<CondCodeSDNode>(N0.getOperand(2))->get());
-
- // If the desired elements are smaller or larger than the source
- // elements we can use a matching integer vector type and then
- // truncate/sign extend
- EVT MatchingVectorType = N0VT.changeVectorElementTypeToInteger();
- if (SVT == MatchingVectorType) {
- SDValue VsetCC = DAG.getSetCC(SDLoc(N), MatchingVectorType,
- N0.getOperand(0), N0.getOperand(1),
- cast<CondCodeSDNode>(N0.getOperand(2))->get());
- return DAG.getSExtOrTrunc(VsetCC, SDLoc(N), VT);
+ // If we already have the desired type, don't change it.
+ if (SVT != N0.getValueType()) {
+ // We know that the # elements of the results is the same as the
+ // # elements of the compare (and the # elements of the compare result
+ // for that matter). Check to see that they are the same size. If so,
+ // we know that the element size of the sext'd result matches the
+ // element size of the compare operands.
+ if (VT.getSizeInBits() == SVT.getSizeInBits())
+ return DAG.getSetCC(DL, VT, N00, N01, CC);
+
+ // If the desired elements are smaller or larger than the source
+ // elements, we can use a matching integer vector type and then
+ // truncate/sign extend.
+ EVT MatchingVecType = N00VT.changeVectorElementTypeToInteger();
+ if (SVT == MatchingVecType) {
+ SDValue VsetCC = DAG.getSetCC(DL, MatchingVecType, N00, N01, CC);
+ return DAG.getSExtOrTrunc(VsetCC, DL, VT);
+ }
}
}
- // sext(setcc x, y, cc) -> (select (setcc x, y, cc), -1, 0)
- unsigned ElementWidth = VT.getScalarType().getSizeInBits();
- SDLoc DL(N);
- SDValue NegOne =
- DAG.getConstant(APInt::getAllOnesValue(ElementWidth), DL, VT);
- SDValue SCC =
- SimplifySelectCC(DL, N0.getOperand(0), N0.getOperand(1),
- NegOne, DAG.getConstant(0, DL, VT),
- cast<CondCodeSDNode>(N0.getOperand(2))->get(), true);
- if (SCC.getNode()) return SCC;
-
- if (!VT.isVector()) {
- EVT SetCCVT = getSetCCResultType(N0.getOperand(0).getValueType());
- if (!LegalOperations ||
- TLI.isOperationLegal(ISD::SETCC, N0.getOperand(0).getValueType())) {
- SDLoc DL(N);
- ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
- SDValue SetCC = DAG.getSetCC(DL, SetCCVT,
- N0.getOperand(0), N0.getOperand(1), CC);
- return DAG.getSelect(DL, VT, SetCC,
- NegOne, DAG.getConstant(0, DL, VT));
+ // sext(setcc x, y, cc) -> (select (setcc x, y, cc), T, 0)
+ // Here, T can be 1 or -1, depending on the type of the setcc and
+ // getBooleanContents().
+ unsigned SetCCWidth = N0.getScalarValueSizeInBits();
+
+ // To determine the "true" side of the select, we need to know the high bit
+ // of the value returned by the setcc if it evaluates to true.
+ // If the type of the setcc is i1, then the true case of the select is just
+ // sext(i1 1), that is, -1.
+ // If the type of the setcc is larger (say, i8) then the value of the high
+ // bit depends on getBooleanContents(), so ask TLI for a real "true" value
+ // of the appropriate width.
+ SDValue ExtTrueVal = (SetCCWidth == 1)
+ ? DAG.getAllOnesConstant(DL, VT)
+ : DAG.getBoolConstant(true, DL, VT, N00VT);
+ SDValue Zero = DAG.getConstant(0, DL, VT);
+ if (SDValue SCC =
+ SimplifySelectCC(DL, N00, N01, ExtTrueVal, Zero, CC, true))
+ return SCC;
+
+ if (!VT.isVector() && !TLI.convertSelectOfConstantsToMath(VT)) {
+ EVT SetCCVT = getSetCCResultType(N00VT);
+ // Don't do this transform for i1 because there's a select transform
+ // that would reverse it.
+ // TODO: We should not do this transform at all without a target hook
+ // because a sext is likely cheaper than a select?
+ if (SetCCVT.getScalarSizeInBits() != 1 &&
+ (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, N00VT))) {
+ SDValue SetCC = DAG.getSetCC(DL, SetCCVT, N00, N01, CC);
+ return DAG.getSelect(DL, VT, SetCC, ExtTrueVal, Zero);
}
}
}
@@ -6173,54 +8799,53 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
// fold (sext x) -> (zext x) if the sign bit is known zero.
if ((!LegalOperations || TLI.isOperationLegal(ISD::ZERO_EXTEND, VT)) &&
DAG.SignBitIsZero(N0))
- return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, N0);
+ return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0);
+
+ if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
+ return NewVSel;
return SDValue();
}
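
(Editorial sketch, not part of the patch: the final sext fold above, (sext x) -> (zext x) when the sign bit is known zero, holds because sign- and zero-extension agree on non-negative values. Quick exhaustive check for i8, illustrative only:)

#include <cassert>
#include <cstdint>

int main() {
  for (int x = 0; x < 128; ++x) {          // sign bit of the i8 known zero
    int8_t Narrow = (int8_t)x;
    assert((uint32_t)(int32_t)Narrow == (uint32_t)(uint8_t)Narrow);
  }
  return 0;
}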
// isTruncateOf - If N is a truncate of some other value, return true, record
-// the value being truncated in Op and which of Op's bits are zero in KnownZero.
-// This function computes KnownZero to avoid a duplicated call to
+// the value being truncated in Op and which of Op's bits are zero/one in Known.
+// This function computes KnownBits to avoid a duplicated call to
// computeKnownBits in the caller.
static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
- APInt &KnownZero) {
- APInt KnownOne;
+ KnownBits &Known) {
if (N->getOpcode() == ISD::TRUNCATE) {
Op = N->getOperand(0);
- DAG.computeKnownBits(Op, KnownZero, KnownOne);
+ Known = DAG.computeKnownBits(Op);
return true;
}
- if (N->getOpcode() != ISD::SETCC || N->getValueType(0) != MVT::i1 ||
- cast<CondCodeSDNode>(N->getOperand(2))->get() != ISD::SETNE)
+ if (N.getOpcode() != ISD::SETCC ||
+ N.getValueType().getScalarType() != MVT::i1 ||
+ cast<CondCodeSDNode>(N.getOperand(2))->get() != ISD::SETNE)
return false;
SDValue Op0 = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
assert(Op0.getValueType() == Op1.getValueType());
- if (isNullConstant(Op0))
+ if (isNullOrNullSplat(Op0))
Op = Op1;
- else if (isNullConstant(Op1))
+ else if (isNullOrNullSplat(Op1))
Op = Op0;
else
return false;
- DAG.computeKnownBits(Op, KnownZero, KnownOne);
-
- if (!(KnownZero | APInt(Op.getValueSizeInBits(), 1)).isAllOnesValue())
- return false;
+ Known = DAG.computeKnownBits(Op);
- return true;
+ return (Known.Zero | 1).isAllOnesValue();
}
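
(Editorial sketch, not part of the patch: isTruncateOf also accepts (setcc Op, 0, ne) as a truncation when every bit of Op but bit 0 is known zero, since the compare and the truncate then yield the same i1. Illustrative check:)

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t Op : {0u, 1u}) {        // all bits but bit 0 known zero
    bool SetCCNe = (Op != 0);           // setcc Op, 0, ne
    bool Trunc = (Op & 1) != 0;         // truncate Op to i1
    assert(SetCCNe == Trunc);
  }
  return 0;
}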
SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);
- if (SDNode *Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes,
- LegalOperations))
- return SDValue(Res, 0);
+ if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
+ return Res;
// fold (zext (zext x)) -> (zext x)
// fold (zext (aext x)) -> (zext x)
@@ -6231,39 +8856,18 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
// fold (zext (truncate x)) -> (zext x) or
// (zext (truncate x)) -> (truncate x)
// This is valid when the truncated bits of x are already zero.
- // FIXME: We should extend this to work for vectors too.
SDValue Op;
- APInt KnownZero;
- if (!VT.isVector() && isTruncateOf(DAG, N0, Op, KnownZero)) {
+ KnownBits Known;
+ if (isTruncateOf(DAG, N0, Op, Known)) {
APInt TruncatedBits =
- (Op.getValueSizeInBits() == N0.getValueSizeInBits()) ?
- APInt(Op.getValueSizeInBits(), 0) :
- APInt::getBitsSet(Op.getValueSizeInBits(),
- N0.getValueSizeInBits(),
- std::min(Op.getValueSizeInBits(),
- VT.getSizeInBits()));
- if (TruncatedBits == (KnownZero & TruncatedBits)) {
- if (VT.bitsGT(Op.getValueType()))
- return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, Op);
- if (VT.bitsLT(Op.getValueType()))
- return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Op);
-
- return Op;
- }
- }
-
- // fold (zext (truncate (load x))) -> (zext (smaller load x))
- // fold (zext (truncate (srl (load x), c))) -> (zext (small load (x+c/n)))
- if (N0.getOpcode() == ISD::TRUNCATE) {
- if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
- SDNode* oye = N0.getNode()->getOperand(0).getNode();
- if (NarrowLoad.getNode() != N0.getNode()) {
- CombineTo(N0.getNode(), NarrowLoad);
- // CombineTo deleted the truncate, if needed, but not what's under it.
- AddToWorklist(oye);
- }
- return SDValue(N, 0); // Return N so it doesn't get rechecked!
- }
+ (Op.getScalarValueSizeInBits() == N0.getScalarValueSizeInBits()) ?
+ APInt(Op.getScalarValueSizeInBits(), 0) :
+ APInt::getBitsSet(Op.getScalarValueSizeInBits(),
+ N0.getScalarValueSizeInBits(),
+ std::min(Op.getScalarValueSizeInBits(),
+ VT.getScalarSizeInBits()));
+ if (TruncatedBits.isSubsetOf(Known.Zero))
+ return DAG.getZExtOrTrunc(Op, SDLoc(N), VT);
}
// fold (zext (truncate x)) -> (and x, mask)
@@ -6271,7 +8875,7 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
// fold (zext (truncate (load x))) -> (zext (smaller load x))
// fold (zext (truncate (srl (load x), c))) -> (zext (smaller load (x+c/n)))
if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
- SDNode *oye = N0.getNode()->getOperand(0).getNode();
+ SDNode *oye = N0.getOperand(0).getNode();
if (NarrowLoad.getNode() != N0.getNode()) {
CombineTo(N0.getNode(), NarrowLoad);
// CombineTo deleted the truncate, if needed, but not what's under it.
@@ -6285,26 +8889,27 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
// Try to mask before the extension to avoid having to generate a larger mask,
// possibly over several sub-vectors.
- if (SrcVT.bitsLT(VT)) {
+ if (SrcVT.bitsLT(VT) && VT.isVector()) {
if (!LegalOperations || (TLI.isOperationLegal(ISD::AND, SrcVT) &&
TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) {
SDValue Op = N0.getOperand(0);
Op = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT.getScalarType());
AddToWorklist(Op.getNode());
- return DAG.getZExtOrTrunc(Op, SDLoc(N), VT);
+ SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, SDLoc(N), VT);
+ // Transfer the debug info; the new node is equivalent to N0.
+ DAG.transferDbgValues(N0, ZExtOrTrunc);
+ return ZExtOrTrunc;
}
}
if (!LegalOperations || TLI.isOperationLegal(ISD::AND, VT)) {
- SDValue Op = N0.getOperand(0);
- if (SrcVT.bitsLT(VT)) {
- Op = DAG.getNode(ISD::ANY_EXTEND, SDLoc(N), VT, Op);
- AddToWorklist(Op.getNode());
- } else if (SrcVT.bitsGT(VT)) {
- Op = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Op);
- AddToWorklist(Op.getNode());
- }
- return DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT.getScalarType());
+ SDValue Op = DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT);
+ AddToWorklist(Op.getNode());
+ SDValue And = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT.getScalarType());
+ // We may safely transfer the debug info describing the truncate node over
+ // to the equivalent and operation.
+ DAG.transferDbgValues(N0, And);
+ return And;
}
}
@@ -6317,11 +8922,7 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
N0.getValueType()) ||
!TLI.isZExtFree(N0.getValueType(), VT))) {
SDValue X = N0.getOperand(0).getOperand(0);
- if (X.getValueType().bitsLT(VT)) {
- X = DAG.getNode(ISD::ANY_EXTEND, SDLoc(X), VT, X);
- } else if (X.getValueType().bitsGT(VT)) {
- X = DAG.getNode(ISD::TRUNCATE, SDLoc(X), VT, X);
- }
+ X = DAG.getAnyExtOrTrunc(X, SDLoc(X), VT);
APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue();
Mask = Mask.zext(VT.getSizeInBits());
SDLoc DL(N);
@@ -6329,35 +8930,11 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
X, DAG.getConstant(Mask, DL, VT));
}
- // fold (zext (load x)) -> (zext (truncate (zextload x)))
- // Only generate vector extloads when 1) they're legal, and 2) they are
- // deemed desirable by the target.
- if (ISD::isNON_EXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
- ((!LegalOperations && !VT.isVector() &&
- !cast<LoadSDNode>(N0)->isVolatile()) ||
- TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, N0.getValueType()))) {
- bool DoXform = true;
- SmallVector<SDNode*, 4> SetCCs;
- if (!N0.hasOneUse())
- DoXform = ExtendUsesToFormExtLoad(N, N0, ISD::ZERO_EXTEND, SetCCs, TLI);
- if (VT.isVector())
- DoXform &= TLI.isVectorLoadExtDesirable(SDValue(N, 0));
- if (DoXform) {
- LoadSDNode *LN0 = cast<LoadSDNode>(N0);
- SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N), VT,
- LN0->getChain(),
- LN0->getBasePtr(), N0.getValueType(),
- LN0->getMemOperand());
- CombineTo(N, ExtLoad);
- SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(N0),
- N0.getValueType(), ExtLoad);
- CombineTo(N0.getNode(), Trunc, ExtLoad.getValue(1));
-
- ExtendSetCCUses(SetCCs, Trunc, ExtLoad, SDLoc(N),
- ISD::ZERO_EXTEND);
- return SDValue(N, 0); // Return N so it doesn't get rechecked!
- }
- }
+ // Try to simplify (zext (load x)).
+ if (SDValue foldedExt =
+ tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
+ ISD::ZEXTLOAD, ISD::ZERO_EXTEND))
+ return foldedExt;
// fold (zext (load x)) to multiple smaller zextloads.
// Only on illegal but splittable vectors.
@@ -6372,120 +8949,110 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
N0.getOpcode() == ISD::XOR) &&
isa<LoadSDNode>(N0.getOperand(0)) &&
N0.getOperand(1).getOpcode() == ISD::Constant &&
- TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, N0.getValueType()) &&
(!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
- LoadSDNode *LN0 = cast<LoadSDNode>(N0.getOperand(0));
- if (LN0->getExtensionType() != ISD::SEXTLOAD && LN0->isUnindexed()) {
+ LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
+ EVT MemVT = LN00->getMemoryVT();
+ if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) &&
+ LN00->getExtensionType() != ISD::SEXTLOAD && LN00->isUnindexed()) {
bool DoXform = true;
SmallVector<SDNode*, 4> SetCCs;
if (!N0.hasOneUse()) {
if (N0.getOpcode() == ISD::AND) {
auto *AndC = cast<ConstantSDNode>(N0.getOperand(1));
- auto NarrowLoad = false;
EVT LoadResultTy = AndC->getValueType(0);
- EVT ExtVT, LoadedVT;
- if (isAndLoadExtLoad(AndC, LN0, LoadResultTy, ExtVT, LoadedVT,
- NarrowLoad))
+ EVT ExtVT;
+ if (isAndLoadExtLoad(AndC, LN00, LoadResultTy, ExtVT))
DoXform = false;
}
- if (DoXform)
- DoXform = ExtendUsesToFormExtLoad(N, N0.getOperand(0),
- ISD::ZERO_EXTEND, SetCCs, TLI);
}
+ if (DoXform)
+ DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
+ ISD::ZERO_EXTEND, SetCCs, TLI);
if (DoXform) {
- SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(LN0), VT,
- LN0->getChain(), LN0->getBasePtr(),
- LN0->getMemoryVT(),
- LN0->getMemOperand());
+ SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(LN00), VT,
+ LN00->getChain(), LN00->getBasePtr(),
+ LN00->getMemoryVT(),
+ LN00->getMemOperand());
APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue();
Mask = Mask.zext(VT.getSizeInBits());
SDLoc DL(N);
SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
ExtLoad, DAG.getConstant(Mask, DL, VT));
- SDValue Trunc = DAG.getNode(ISD::TRUNCATE,
- SDLoc(N0.getOperand(0)),
- N0.getOperand(0).getValueType(), ExtLoad);
+ ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
+ bool NoReplaceTruncAnd = !N0.hasOneUse();
+ bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
CombineTo(N, And);
- CombineTo(N0.getOperand(0).getNode(), Trunc, ExtLoad.getValue(1));
- ExtendSetCCUses(SetCCs, Trunc, ExtLoad, DL,
- ISD::ZERO_EXTEND);
- return SDValue(N, 0); // Return N so it doesn't get rechecked!
+ // If N0 has multiple uses, change other uses as well.
+ if (NoReplaceTruncAnd) {
+ SDValue TruncAnd =
+ DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And);
+ CombineTo(N0.getNode(), TruncAnd);
+ }
+ if (NoReplaceTrunc) {
+ DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
+ } else {
+ SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
+ LN00->getValueType(0), ExtLoad);
+ CombineTo(LN00, Trunc, ExtLoad.getValue(1));
+ }
+ return SDValue(N,0); // Return N so it doesn't get rechecked!
}
}
}
- // fold (zext (zextload x)) -> (zext (truncate (zextload x)))
- // fold (zext ( extload x)) -> (zext (truncate (zextload x)))
- if ((ISD::isZEXTLoad(N0.getNode()) || ISD::isEXTLoad(N0.getNode())) &&
- ISD::isUNINDEXEDLoad(N0.getNode()) && N0.hasOneUse()) {
- LoadSDNode *LN0 = cast<LoadSDNode>(N0);
- EVT MemVT = LN0->getMemoryVT();
- if ((!LegalOperations && !LN0->isVolatile()) ||
- TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT)) {
- SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N), VT,
- LN0->getChain(),
- LN0->getBasePtr(), MemVT,
- LN0->getMemOperand());
- CombineTo(N, ExtLoad);
- CombineTo(N0.getNode(),
- DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(),
- ExtLoad),
- ExtLoad.getValue(1));
- return SDValue(N, 0); // Return N so it doesn't get rechecked!
- }
- }
+ // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
+ // (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
+ if (SDValue ZExtLoad = CombineZExtLogicopShiftLoad(N))
+ return ZExtLoad;
+
+ // Try to simplify (zext (zextload x)).
+ if (SDValue foldedExt = tryToFoldExtOfExtload(
+ DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::ZEXTLOAD))
+ return foldedExt;
+
+ if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
+ return V;
if (N0.getOpcode() == ISD::SETCC) {
+ // Only do this before legalize for now.
if (!LegalOperations && VT.isVector() &&
N0.getValueType().getVectorElementType() == MVT::i1) {
- EVT N0VT = N0.getOperand(0).getValueType();
- if (getSetCCResultType(N0VT) == N0.getValueType())
+ EVT N00VT = N0.getOperand(0).getValueType();
+ if (getSetCCResultType(N00VT) == N0.getValueType())
return SDValue();
- // zext(setcc) -> (and (vsetcc), (1, 1, ...) for vectors.
- // Only do this before legalize for now.
- EVT EltVT = VT.getVectorElementType();
+ // We know that the # elements of the results is the same as the #
+ // elements of the compare (and the # elements of the compare result for
+ // that matter). Check to see that they are the same size. If so, we know
+ // that the element size of the extended result matches the element size of
+ // the compare operands.
SDLoc DL(N);
- SmallVector<SDValue,8> OneOps(VT.getVectorNumElements(),
- DAG.getConstant(1, DL, EltVT));
- if (VT.getSizeInBits() == N0VT.getSizeInBits())
- // We know that the # elements of the results is the same as the
- // # elements of the compare (and the # elements of the compare result
- // for that matter). Check to see that they are the same size. If so,
- // we know that the element size of the sext'd result matches the
- // element size of the compare operands.
- return DAG.getNode(ISD::AND, DL, VT,
- DAG.getSetCC(DL, VT, N0.getOperand(0),
- N0.getOperand(1),
- cast<CondCodeSDNode>(N0.getOperand(2))->get()),
- DAG.getNode(ISD::BUILD_VECTOR, DL, VT,
- OneOps));
+ SDValue VecOnes = DAG.getConstant(1, DL, VT);
+ if (VT.getSizeInBits() == N00VT.getSizeInBits()) {
+ // zext(setcc) -> (and (vsetcc), (1, 1, ...)) for vectors.
+ SDValue VSetCC = DAG.getNode(ISD::SETCC, DL, VT, N0.getOperand(0),
+ N0.getOperand(1), N0.getOperand(2));
+ return DAG.getNode(ISD::AND, DL, VT, VSetCC, VecOnes);
+ }
// If the desired elements are smaller or larger than the source
// elements we can use a matching integer vector type and then
- // truncate/sign extend
- EVT MatchingElementType =
- EVT::getIntegerVT(*DAG.getContext(),
- N0VT.getScalarType().getSizeInBits());
- EVT MatchingVectorType =
- EVT::getVectorVT(*DAG.getContext(), MatchingElementType,
- N0VT.getVectorNumElements());
+ // truncate/sign extend.
+ EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
SDValue VsetCC =
- DAG.getSetCC(DL, MatchingVectorType, N0.getOperand(0),
- N0.getOperand(1),
- cast<CondCodeSDNode>(N0.getOperand(2))->get());
- return DAG.getNode(ISD::AND, DL, VT,
- DAG.getSExtOrTrunc(VsetCC, DL, VT),
- DAG.getNode(ISD::BUILD_VECTOR, DL, VT, OneOps));
+ DAG.getNode(ISD::SETCC, DL, MatchingVectorType, N0.getOperand(0),
+ N0.getOperand(1), N0.getOperand(2));
+ return DAG.getNode(ISD::AND, DL, VT, DAG.getSExtOrTrunc(VsetCC, DL, VT),
+ VecOnes);
}
// zext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc
SDLoc DL(N);
- SDValue SCC =
- SimplifySelectCC(DL, N0.getOperand(0), N0.getOperand(1),
- DAG.getConstant(1, DL, VT), DAG.getConstant(0, DL, VT),
- cast<CondCodeSDNode>(N0.getOperand(2))->get(), true);
- if (SCC.getNode()) return SCC;
+ if (SDValue SCC = SimplifySelectCC(
+ DL, N0.getOperand(0), N0.getOperand(1), DAG.getConstant(1, DL, VT),
+ DAG.getConstant(0, DL, VT),
+ cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
+ return SCC;
}
// (zext (shl (zext x), cst)) -> (shl (zext x), cst)
@@ -6499,8 +9066,8 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
SDValue InnerZExt = N0.getOperand(0);
// If the original shl may be shifting out bits, do not perform this
// transformation.
- unsigned KnownZeroBits = InnerZExt.getValueType().getSizeInBits() -
- InnerZExt.getOperand(0).getValueType().getSizeInBits();
+ unsigned KnownZeroBits = InnerZExt.getValueSizeInBits() -
+ InnerZExt.getOperand(0).getValueSizeInBits();
if (ShAmtVal > KnownZeroBits)
return SDValue();
}
@@ -6516,6 +9083,9 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
ShAmt);
}
+ if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
+ return NewVSel;
+
return SDValue();
}
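
(Editorial sketch, not part of the patch: among the zext folds above, (zext (truncate x)) -> (and x, mask) is the workhorse; dropping to a narrow type and zero-extending back keeps exactly the low bits. Scalar check, illustrative only:)

#include <cassert>
#include <cstdint>

int main() {
  uint32_t X = 0xDEADBEEF;
  assert((uint32_t)(uint8_t)X == (X & 0xFF));  // zext (trunc x) == and x, 0xFF
  return 0;
}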
@@ -6523,9 +9093,8 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);
- if (SDNode *Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes,
- LegalOperations))
- return SDValue(Res, 0);
+ if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
+ return Res;
// fold (aext (aext x)) -> (aext x)
// fold (aext (zext x)) -> (zext x)
@@ -6539,7 +9108,7 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
// fold (aext (truncate (srl (load x), c))) -> (aext (small load (x+c/n)))
if (N0.getOpcode() == ISD::TRUNCATE) {
if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
- SDNode* oye = N0.getNode()->getOperand(0).getNode();
+ SDNode *oye = N0.getOperand(0).getNode();
if (NarrowLoad.getNode() != N0.getNode()) {
CombineTo(N0.getNode(), NarrowLoad);
// CombineTo deleted the truncate, if needed, but not what's under it.
@@ -6550,14 +9119,8 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
}
// fold (aext (truncate x))
- if (N0.getOpcode() == ISD::TRUNCATE) {
- SDValue TruncOp = N0.getOperand(0);
- if (TruncOp.getValueType() == VT)
- return TruncOp; // x iff x size == zext size.
- if (TruncOp.getValueType().bitsGT(VT))
- return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, TruncOp);
- return DAG.getNode(ISD::ANY_EXTEND, SDLoc(N), VT, TruncOp);
- }
+ if (N0.getOpcode() == ISD::TRUNCATE)
+ return DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT);
// Fold (aext (and (trunc x), cst)) -> (and x, cst)
// if the trunc is not free.
@@ -6566,15 +9129,11 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
N0.getOperand(1).getOpcode() == ISD::Constant &&
!TLI.isTruncateFree(N0.getOperand(0).getOperand(0).getValueType(),
N0.getValueType())) {
+ SDLoc DL(N);
SDValue X = N0.getOperand(0).getOperand(0);
- if (X.getValueType().bitsLT(VT)) {
- X = DAG.getNode(ISD::ANY_EXTEND, SDLoc(N), VT, X);
- } else if (X.getValueType().bitsGT(VT)) {
- X = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, X);
- }
+ X = DAG.getAnyExtOrTrunc(X, DL, VT);
APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue();
Mask = Mask.zext(VT.getSizeInBits());
- SDLoc DL(N);
return DAG.getNode(ISD::AND, DL, VT,
X, DAG.getConstant(Mask, DL, VT));
}
@@ -6589,29 +9148,34 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
bool DoXform = true;
SmallVector<SDNode*, 4> SetCCs;
if (!N0.hasOneUse())
- DoXform = ExtendUsesToFormExtLoad(N, N0, ISD::ANY_EXTEND, SetCCs, TLI);
+ DoXform = ExtendUsesToFormExtLoad(VT, N, N0, ISD::ANY_EXTEND, SetCCs,
+ TLI);
if (DoXform) {
LoadSDNode *LN0 = cast<LoadSDNode>(N0);
SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, SDLoc(N), VT,
LN0->getChain(),
LN0->getBasePtr(), N0.getValueType(),
LN0->getMemOperand());
+ ExtendSetCCUses(SetCCs, N0, ExtLoad, ISD::ANY_EXTEND);
+ // If the load value is used only by N, replace it via CombineTo N.
+ bool NoReplaceTrunc = N0.hasOneUse();
CombineTo(N, ExtLoad);
- SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(N0),
- N0.getValueType(), ExtLoad);
- CombineTo(N0.getNode(), Trunc, ExtLoad.getValue(1));
- ExtendSetCCUses(SetCCs, Trunc, ExtLoad, SDLoc(N),
- ISD::ANY_EXTEND);
- return SDValue(N, 0); // Return N so it doesn't get rechecked!
+ if (NoReplaceTrunc) {
+ DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
+ } else {
+ SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(N0),
+ N0.getValueType(), ExtLoad);
+ CombineTo(LN0, Trunc, ExtLoad.getValue(1));
+ }
+ return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
}
// fold (aext (zextload x)) -> (aext (truncate (zextload x)))
// fold (aext (sextload x)) -> (aext (truncate (sextload x)))
// fold (aext ( extload x)) -> (aext (truncate (extload x)))
- if (N0.getOpcode() == ISD::LOAD &&
- !ISD::isNON_EXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
- N0.hasOneUse()) {
+ if (N0.getOpcode() == ISD::LOAD && !ISD::isNON_EXTLoad(N0.getNode()) &&
+ ISD::isUNINDEXEDLoad(N0.getNode()) && N0.hasOneUse()) {
LoadSDNode *LN0 = cast<LoadSDNode>(N0);
ISD::LoadExtType ExtType = LN0->getExtensionType();
EVT MemVT = LN0->getMemoryVT();
@@ -6620,10 +9184,7 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
VT, LN0->getChain(), LN0->getBasePtr(),
MemVT, LN0->getMemOperand());
CombineTo(N, ExtLoad);
- CombineTo(N0.getNode(),
- DAG.getNode(ISD::TRUNCATE, SDLoc(N0),
- N0.getValueType(), ExtLoad),
- ExtLoad.getValue(1));
+ DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
}
@@ -6635,89 +9196,104 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
// aext(setcc) -> aext(vsetcc)
// Only do this before legalize for now.
if (VT.isVector() && !LegalOperations) {
- EVT N0VT = N0.getOperand(0).getValueType();
- // We know that the # elements of the results is the same as the
- // # elements of the compare (and the # elements of the compare result
- // for that matter). Check to see that they are the same size. If so,
- // we know that the element size of the sext'd result matches the
- // element size of the compare operands.
- if (VT.getSizeInBits() == N0VT.getSizeInBits())
+ EVT N00VT = N0.getOperand(0).getValueType();
+ if (getSetCCResultType(N00VT) == N0.getValueType())
+ return SDValue();
+
+  // We know that the # elements of the result is the same as the
+ // # elements of the compare (and the # elements of the compare result
+ // for that matter). Check to see that they are the same size. If so,
+ // we know that the element size of the sext'd result matches the
+ // element size of the compare operands.
+ if (VT.getSizeInBits() == N00VT.getSizeInBits())
return DAG.getSetCC(SDLoc(N), VT, N0.getOperand(0),
N0.getOperand(1),
cast<CondCodeSDNode>(N0.getOperand(2))->get());
+
// If the desired elements are smaller or larger than the source
// elements we can use a matching integer vector type and then
// truncate/any extend
- else {
- EVT MatchingVectorType = N0VT.changeVectorElementTypeToInteger();
- SDValue VsetCC =
- DAG.getSetCC(SDLoc(N), MatchingVectorType, N0.getOperand(0),
- N0.getOperand(1),
- cast<CondCodeSDNode>(N0.getOperand(2))->get());
- return DAG.getAnyExtOrTrunc(VsetCC, SDLoc(N), VT);
- }
+ EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
+ SDValue VsetCC =
+ DAG.getSetCC(SDLoc(N), MatchingVectorType, N0.getOperand(0),
+ N0.getOperand(1),
+ cast<CondCodeSDNode>(N0.getOperand(2))->get());
+ return DAG.getAnyExtOrTrunc(VsetCC, SDLoc(N), VT);
}
// aext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc
SDLoc DL(N);
- SDValue SCC =
- SimplifySelectCC(DL, N0.getOperand(0), N0.getOperand(1),
- DAG.getConstant(1, DL, VT), DAG.getConstant(0, DL, VT),
- cast<CondCodeSDNode>(N0.getOperand(2))->get(), true);
- if (SCC.getNode())
+ if (SDValue SCC = SimplifySelectCC(
+ DL, N0.getOperand(0), N0.getOperand(1), DAG.getConstant(1, DL, VT),
+ DAG.getConstant(0, DL, VT),
+ cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
return SCC;
}
return SDValue();
}
-/// See if the specified operand can be simplified with the knowledge that only
-/// the bits specified by Mask are used. If so, return the simpler operand,
-/// otherwise return a null SDValue.
-SDValue DAGCombiner::GetDemandedBits(SDValue V, const APInt &Mask) {
- switch (V.getOpcode()) {
- default: break;
- case ISD::Constant: {
- const ConstantSDNode *CV = cast<ConstantSDNode>(V.getNode());
- assert(CV && "Const value should be ConstSDNode.");
- const APInt &CVal = CV->getAPIntValue();
- APInt NewVal = CVal & Mask;
- if (NewVal != CVal)
- return DAG.getConstant(NewVal, SDLoc(V), V.getValueType());
- break;
- }
- case ISD::OR:
- case ISD::XOR:
- // If the LHS or RHS don't contribute bits to the or, drop them.
- if (DAG.MaskedValueIsZero(V.getOperand(0), Mask))
- return V.getOperand(1);
- if (DAG.MaskedValueIsZero(V.getOperand(1), Mask))
- return V.getOperand(0);
- break;
- case ISD::SRL:
- // Only look at single-use SRLs.
- if (!V.getNode()->hasOneUse())
- break;
- if (ConstantSDNode *RHSC = getAsNonOpaqueConstant(V.getOperand(1))) {
- // See if we can recursively simplify the LHS.
- unsigned Amt = RHSC->getZExtValue();
+SDValue DAGCombiner::visitAssertExt(SDNode *N) {
+ unsigned Opcode = N->getOpcode();
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ EVT AssertVT = cast<VTSDNode>(N1)->getVT();
+
+ // fold (assert?ext (assert?ext x, vt), vt) -> (assert?ext x, vt)
+ if (N0.getOpcode() == Opcode &&
+ AssertVT == cast<VTSDNode>(N0.getOperand(1))->getVT())
+ return N0;
- // Watch out for shift count overflow though.
- if (Amt >= Mask.getBitWidth()) break;
- APInt NewMask = Mask << Amt;
- if (SDValue SimplifyLHS = GetDemandedBits(V.getOperand(0), NewMask))
- return DAG.getNode(ISD::SRL, SDLoc(V), V.getValueType(),
- SimplifyLHS, V.getOperand(1));
+ if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
+ N0.getOperand(0).getOpcode() == Opcode) {
+ // We have an assert, truncate, assert sandwich. Make one stronger assert
+  // by applying the smaller of the two asserted types to the larger source type.
+ // This eliminates the later assert:
+ // assert (trunc (assert X, i8) to iN), i1 --> trunc (assert X, i1) to iN
+ // assert (trunc (assert X, i1) to iN), i8 --> trunc (assert X, i1) to iN
+ SDValue BigA = N0.getOperand(0);
+ EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
+ assert(BigA_AssertVT.bitsLE(N0.getValueType()) &&
+ "Asserting zero/sign-extended bits to a type larger than the "
+ "truncated destination does not provide information");
+
+ SDLoc DL(N);
+ EVT MinAssertVT = AssertVT.bitsLT(BigA_AssertVT) ? AssertVT : BigA_AssertVT;
+ SDValue MinAssertVTVal = DAG.getValueType(MinAssertVT);
+ SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
+ BigA.getOperand(0), MinAssertVTVal);
+ return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
+ }
+
+  // If we have (AssertZext (truncate (AssertSext X, iX)), iY) where Y is smaller
+  // than X, just move the AssertZext in front of the truncate and drop the
+ // AssertSExt.
+ if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
+ N0.getOperand(0).getOpcode() == ISD::AssertSext &&
+ Opcode == ISD::AssertZext) {
+ SDValue BigA = N0.getOperand(0);
+ EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
+ assert(BigA_AssertVT.bitsLE(N0.getValueType()) &&
+ "Asserting zero/sign-extended bits to a type larger than the "
+ "truncated destination does not provide information");
+
+ if (AssertVT.bitsLT(BigA_AssertVT)) {
+ SDLoc DL(N);
+ SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
+ BigA.getOperand(0), N1);
+ return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
}
}
+
return SDValue();
}
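(The sandwich folds above rest on a simple monotonicity fact: an assertion that all bits above width W are zero (or sign copies) also holds for every width larger than W, so only the smallest asserted width carries information. A plain-C++ analogue, illustration only:)

    #include <cassert>
    #include <cstdint>

    // "All bits above width w are zero" -- the zext flavor of the assert.
    static bool highBitsZero(uint64_t v, unsigned w) {
      return w >= 64 || (v >> w) == 0;
    }

    int main() {
      uint64_t v = 1;               // only bit 0 set
      assert(highBitsZero(v, 1));   // the smaller asserted width...
      assert(highBitsZero(v, 8));   // ...implies every larger one
      return 0;
    }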
/// If the result of a wider load is shifted to right of N bits and then
/// truncated to a narrower type and where N is a multiple of number of bits of
/// the narrower type, transform it to a narrower load from address + N / num of
-/// bits of new type. If the result is to be extended, also fold the extension
-/// to form a extending load.
+/// bits of new type. Also narrow the load if the result is masked with an AND
+/// to effectively produce a smaller type. If the result is to be extended, also
+/// fold the extension to form an extending load.
SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
unsigned Opc = N->getOpcode();
@@ -6730,56 +9306,100 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
if (VT.isVector())
return SDValue();
+ unsigned ShAmt = 0;
+ bool HasShiftedOffset = false;
// Special case: SIGN_EXTEND_INREG is basically truncating to ExtVT then
// extended to VT.
if (Opc == ISD::SIGN_EXTEND_INREG) {
ExtType = ISD::SEXTLOAD;
ExtVT = cast<VTSDNode>(N->getOperand(1))->getVT();
} else if (Opc == ISD::SRL) {
- // Another special-case: SRL is basically zero-extending a narrower value.
+ // Another special-case: SRL is basically zero-extending a narrower value,
+  // or it may be shifting a higher subword, half, or byte into the lowest
+ // bits.
ExtType = ISD::ZEXTLOAD;
N0 = SDValue(N, 0);
- ConstantSDNode *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1));
- if (!N01) return SDValue();
- ExtVT = EVT::getIntegerVT(*DAG.getContext(),
- VT.getSizeInBits() - N01->getZExtValue());
- }
- if (LegalOperations && !TLI.isLoadExtLegal(ExtType, VT, ExtVT))
- return SDValue();
- unsigned EVTBits = ExtVT.getSizeInBits();
+ auto *LN0 = dyn_cast<LoadSDNode>(N0.getOperand(0));
+ auto *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1));
+ if (!N01 || !LN0)
+ return SDValue();
- // Do not generate loads of non-round integer types since these can
- // be expensive (and would be wrong if the type is not byte sized).
- if (!ExtVT.isRound())
- return SDValue();
+ uint64_t ShiftAmt = N01->getZExtValue();
+ uint64_t MemoryWidth = LN0->getMemoryVT().getSizeInBits();
+ if (LN0->getExtensionType() != ISD::SEXTLOAD && MemoryWidth > ShiftAmt)
+ ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShiftAmt);
+ else
+ ExtVT = EVT::getIntegerVT(*DAG.getContext(),
+ VT.getSizeInBits() - ShiftAmt);
+ } else if (Opc == ISD::AND) {
+ // An AND with a constant mask is the same as a truncate + zero-extend.
+ auto AndC = dyn_cast<ConstantSDNode>(N->getOperand(1));
+ if (!AndC)
+ return SDValue();
+
+ const APInt &Mask = AndC->getAPIntValue();
+ unsigned ActiveBits = 0;
+ if (Mask.isMask()) {
+ ActiveBits = Mask.countTrailingOnes();
+ } else if (Mask.isShiftedMask()) {
+ ShAmt = Mask.countTrailingZeros();
+ APInt ShiftedMask = Mask.lshr(ShAmt);
+ ActiveBits = ShiftedMask.countTrailingOnes();
+ HasShiftedOffset = true;
+ } else
+ return SDValue();
+
+ ExtType = ISD::ZEXTLOAD;
+ ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
+ }
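(The mask classification above is the whole trick: an isMask constant keeps the low ActiveBits, and an isShiftedMask constant is the same run of ones moved up by ShAmt, i.e. a narrow zero-extending load at a byte offset. A standalone sketch of the same arithmetic, assumed to mirror the APInt countTrailingZeros/countTrailingOnes semantics:)

    #include <cassert>
    #include <cstdint>

    int main() {
      uint32_t mask = 0x00FF0000u;  // a shifted mask: run of ones at bits 16..23
      unsigned shAmt = 0;
      while (((mask >> shAmt) & 1) == 0) ++shAmt;               // countTrailingZeros
      unsigned activeBits = 0;
      while ((mask >> (shAmt + activeBits)) & 1) ++activeBits;  // countTrailingOnes
      assert(shAmt == 16 && activeBits == 8);
      // The transform applies only if the mask is exactly ones << shAmt:
      assert(mask == (((1u << activeBits) - 1) << shAmt));
      return 0;
    }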
- unsigned ShAmt = 0;
if (N0.getOpcode() == ISD::SRL && N0.hasOneUse()) {
- if (ConstantSDNode *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
- ShAmt = N01->getZExtValue();
+ SDValue SRL = N0;
+ if (auto *ConstShift = dyn_cast<ConstantSDNode>(SRL.getOperand(1))) {
+ ShAmt = ConstShift->getZExtValue();
+ unsigned EVTBits = ExtVT.getSizeInBits();
// Is the shift amount a multiple of size of VT?
if ((ShAmt & (EVTBits-1)) == 0) {
N0 = N0.getOperand(0);
// Is the load width a multiple of size of VT?
- if ((N0.getValueType().getSizeInBits() & (EVTBits-1)) != 0)
+ if ((N0.getValueSizeInBits() & (EVTBits-1)) != 0)
return SDValue();
}
// At this point, we must have a load or else we can't do the transform.
if (!isa<LoadSDNode>(N0)) return SDValue();
+ auto *LN0 = cast<LoadSDNode>(N0);
+
// Because a SRL must be assumed to *need* to zero-extend the high bits
// (as opposed to anyext the high bits), we can't combine the zextload
// lowering of SRL and an sextload.
- if (cast<LoadSDNode>(N0)->getExtensionType() == ISD::SEXTLOAD)
+ if (LN0->getExtensionType() == ISD::SEXTLOAD)
return SDValue();
// If the shift amount is larger than the input type then we're not
// accessing any of the loaded bytes. If the load was a zextload/extload
// then the result of the shift+trunc is zero/undef (handled elsewhere).
- if (ShAmt >= cast<LoadSDNode>(N0)->getMemoryVT().getSizeInBits())
+ if (ShAmt >= LN0->getMemoryVT().getSizeInBits())
return SDValue();
+
+ // If the SRL is only used by a masking AND, we may be able to adjust
+ // the ExtVT to make the AND redundant.
+ SDNode *Mask = *(SRL->use_begin());
+ if (Mask->getOpcode() == ISD::AND &&
+ isa<ConstantSDNode>(Mask->getOperand(1))) {
+ const APInt &ShiftMask =
+ cast<ConstantSDNode>(Mask->getOperand(1))->getAPIntValue();
+ if (ShiftMask.isMask()) {
+ EVT MaskedVT = EVT::getIntegerVT(*DAG.getContext(),
+ ShiftMask.countTrailingOnes());
+ // If the mask is smaller, recompute the type.
+ if ((ExtVT.getSizeInBits() > MaskedVT.getSizeInBits()) &&
+ TLI.isLoadExtLegal(ExtType, N0.getValueType(), MaskedVT))
+ ExtVT = MaskedVT;
+ }
+ }
}
}
@@ -6794,52 +9414,26 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
}
}
- // If we haven't found a load, we can't narrow it. Don't transform one with
- // multiple uses, this would require adding a new load.
- if (!isa<LoadSDNode>(N0) || !N0.hasOneUse())
+ // If we haven't found a load, we can't narrow it.
+ if (!isa<LoadSDNode>(N0))
return SDValue();
- // Don't change the width of a volatile load.
LoadSDNode *LN0 = cast<LoadSDNode>(N0);
- if (LN0->isVolatile())
- return SDValue();
-
- // Verify that we are actually reducing a load width here.
- if (LN0->getMemoryVT().getSizeInBits() < EVTBits)
- return SDValue();
-
- // For the transform to be legal, the load must produce only two values
- // (the value loaded and the chain). Don't transform a pre-increment
- // load, for example, which produces an extra value. Otherwise the
- // transformation is not equivalent, and the downstream logic to replace
- // uses gets things wrong.
- if (LN0->getNumValues() > 2)
+ if (!isLegalNarrowLdSt(LN0, ExtType, ExtVT, ShAmt))
return SDValue();
- // If the load that we're shrinking is an extload and we're not just
- // discarding the extension we can't simply shrink the load. Bail.
- // TODO: It would be possible to merge the extensions in some cases.
- if (LN0->getExtensionType() != ISD::NON_EXTLOAD &&
- LN0->getMemoryVT().getSizeInBits() < ExtVT.getSizeInBits() + ShAmt)
- return SDValue();
-
- if (!TLI.shouldReduceLoadWidth(LN0, ExtType, ExtVT))
- return SDValue();
-
- EVT PtrType = N0.getOperand(1).getValueType();
-
- if (PtrType == MVT::Untyped || PtrType.isExtended())
- // It's not possible to generate a constant of extended or untyped type.
- return SDValue();
+ auto AdjustBigEndianShift = [&](unsigned ShAmt) {
+ unsigned LVTStoreBits = LN0->getMemoryVT().getStoreSizeInBits();
+ unsigned EVTStoreBits = ExtVT.getStoreSizeInBits();
+ return LVTStoreBits - EVTStoreBits - ShAmt;
+ };
// For big endian targets, we need to adjust the offset to the pointer to
// load the correct bytes.
- if (DAG.getDataLayout().isBigEndian()) {
- unsigned LVTStoreBits = LN0->getMemoryVT().getStoreSizeInBits();
- unsigned EVTStoreBits = ExtVT.getStoreSizeInBits();
- ShAmt = LVTStoreBits - EVTStoreBits - ShAmt;
- }
+ if (DAG.getDataLayout().isBigEndian())
+ ShAmt = AdjustBigEndianShift(ShAmt);
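(The big-endian adjustment computes where the desired ExtVT-sized field actually sits in memory. A self-contained check of the LVTStoreBits - EVTStoreBits - ShAmt formula for a 32-bit load narrowed to one byte, illustration only:)

    #include <cassert>
    #include <cstdint>

    int main() {
      unsigned memBits = 32, extBits = 8, shAmt = 8;
      // Little-endian: byte offset is simply shAmt / 8 == 1.
      // Big-endian: the same field sits at (memBits - extBits - shAmt) / 8.
      unsigned beShAmt = memBits - extBits - shAmt;      // == 16, so byte 2
      uint8_t image[4] = {0x11, 0x22, 0x33, 0x44};       // 0x11223344 stored BE
      uint8_t expected = (0x11223344u >> shAmt) & 0xFF;  // == 0x33
      assert(image[beShAmt / 8] == expected);
      return 0;
    }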
+ EVT PtrType = N0.getOperand(1).getValueType();
uint64_t PtrOff = ShAmt / 8;
unsigned NewAlign = MinAlign(LN0->getAlignment(), PtrOff);
SDLoc DL(LN0);
@@ -6849,20 +9443,19 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
SDValue NewPtr = DAG.getNode(ISD::ADD, DL,
PtrType, LN0->getBasePtr(),
DAG.getConstant(PtrOff, DL, PtrType),
- &Flags);
+ Flags);
AddToWorklist(NewPtr.getNode());
SDValue Load;
if (ExtType == ISD::NON_EXTLOAD)
- Load = DAG.getLoad(VT, SDLoc(N0), LN0->getChain(), NewPtr,
- LN0->getPointerInfo().getWithOffset(PtrOff),
- LN0->isVolatile(), LN0->isNonTemporal(),
- LN0->isInvariant(), NewAlign, LN0->getAAInfo());
+ Load = DAG.getLoad(VT, SDLoc(N0), LN0->getChain(), NewPtr,
+ LN0->getPointerInfo().getWithOffset(PtrOff), NewAlign,
+ LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
else
- Load = DAG.getExtLoad(ExtType, SDLoc(N0), VT, LN0->getChain(),NewPtr,
- LN0->getPointerInfo().getWithOffset(PtrOff),
- ExtVT, LN0->isVolatile(), LN0->isNonTemporal(),
- LN0->isInvariant(), NewAlign, LN0->getAAInfo());
+ Load = DAG.getExtLoad(ExtType, SDLoc(N0), VT, LN0->getChain(), NewPtr,
+ LN0->getPointerInfo().getWithOffset(PtrOff), ExtVT,
+ NewAlign, LN0->getMemOperand()->getFlags(),
+ LN0->getAAInfo());
// Replace the old load's chain with the new load's chain.
WorklistRemover DeadNodes(*this);
@@ -6886,6 +9479,21 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
Result, DAG.getConstant(ShLeftAmt, DL, ShImmTy));
}
+ if (HasShiftedOffset) {
+ // Recalculate the shift amount after it has been altered to calculate
+ // the offset.
+ if (DAG.getDataLayout().isBigEndian())
+ ShAmt = AdjustBigEndianShift(ShAmt);
+
+ // We're using a shifted mask, so the load now has an offset. This means
+  // that the data has been loaded into lower bytes than it would have been
+ // before, so we need to shl the loaded data into the correct position in the
+ // register.
+ SDValue ShiftC = DAG.getConstant(ShAmt, DL, VT);
+ Result = DAG.getNode(ISD::SHL, DL, VT, Result, ShiftC);
+ DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
+ }
+
// Return the new loaded value.
return Result;
}
@@ -6895,14 +9503,14 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
SDValue N1 = N->getOperand(1);
EVT VT = N->getValueType(0);
EVT EVT = cast<VTSDNode>(N1)->getVT();
- unsigned VTBits = VT.getScalarType().getSizeInBits();
- unsigned EVTBits = EVT.getScalarType().getSizeInBits();
+ unsigned VTBits = VT.getScalarSizeInBits();
+ unsigned EVTBits = EVT.getScalarSizeInBits();
if (N0.isUndef())
return DAG.getUNDEF(VT);
// fold (sext_in_reg c1) -> c1
- if (isConstantIntBuildVectorOrConstantInt(N0))
+ if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, N0, N1);
// If the input is already sign extended, just drop the extension.
@@ -6917,17 +9525,40 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
// fold (sext_in_reg (sext x)) -> (sext x)
// fold (sext_in_reg (aext x)) -> (sext x)
- // if x is small enough.
+ // if x is small enough or if we know that x has more than 1 sign bit and the
+ // sign_extend_inreg is extending from one of them.
if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
SDValue N00 = N0.getOperand(0);
- if (N00.getValueType().getScalarType().getSizeInBits() <= EVTBits &&
+ unsigned N00Bits = N00.getScalarValueSizeInBits();
+ if ((N00Bits <= EVTBits ||
+ (N00Bits - DAG.ComputeNumSignBits(N00)) < EVTBits) &&
+ (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
+ return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00);
+ }
+
+ // fold (sext_in_reg (*_extend_vector_inreg x)) -> (sext_vector_inreg x)
+ if ((N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
+ N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG ||
+ N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG) &&
+ N0.getOperand(0).getScalarValueSizeInBits() == EVTBits) {
+ if (!LegalOperations ||
+ TLI.isOperationLegal(ISD::SIGN_EXTEND_VECTOR_INREG, VT))
+ return DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, SDLoc(N), VT,
+ N0.getOperand(0));
+ }
+
+ // fold (sext_in_reg (zext x)) -> (sext x)
+ // iff we are extending the source sign bit.
+ if (N0.getOpcode() == ISD::ZERO_EXTEND) {
+ SDValue N00 = N0.getOperand(0);
+ if (N00.getScalarValueSizeInBits() == EVTBits &&
(!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00, N1);
}
// fold (sext_in_reg x) -> (zext_in_reg x) if the sign bit is known zero.
- if (DAG.MaskedValueIsZero(N0, APInt::getBitsSet(VTBits, EVTBits-1, EVTBits)))
- return DAG.getZeroExtendInReg(N0, SDLoc(N), EVT);
+ if (DAG.MaskedValueIsZero(N0, APInt::getOneBitSet(VTBits, EVTBits - 1)))
+ return DAG.getZeroExtendInReg(N0, SDLoc(N), EVT.getScalarType());
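(The zext_in_reg rewrite is justified because sign- and zero-extension from width EVTBits agree exactly when bit EVTBits-1 is zero, which is what the MaskedValueIsZero query establishes. A quick exhaustive check at i8, not part of the patch:)

    #include <cassert>
    #include <cstdint>

    int main() {
      // With bit 7 (the i8 sign bit) known zero, sign- and zero-extension
      // from i8 to i32 produce identical bits.
      for (uint32_t x = 0; x < 0x80; ++x) {
        int32_t sext = static_cast<int8_t>(x);
        uint32_t zext = static_cast<uint8_t>(x);
        assert(static_cast<uint32_t>(sext) == zext);
      }
      return 0;
    }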
// fold operands of sext_in_reg based on knowledge that the top bits are not
// demanded.
@@ -6955,10 +9586,14 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
}
// fold (sext_inreg (extload x)) -> (sextload x)
+  // If sextload is not supported by the target, we can only do the combine
+  // when the load has one use. Doing otherwise can block folding the extload
+  // with other extends that the target does support.
if (ISD::isEXTLoad(N0.getNode()) &&
ISD::isUNINDEXEDLoad(N0.getNode()) &&
EVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
- ((!LegalOperations && !cast<LoadSDNode>(N0)->isVolatile()) ||
+ ((!LegalOperations && !cast<LoadSDNode>(N0)->isVolatile() &&
+ N0.hasOneUse()) ||
TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, EVT))) {
LoadSDNode *LN0 = cast<LoadSDNode>(N0);
SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT,
@@ -6988,9 +9623,8 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
// Form (sext_inreg (bswap >> 16)) or (sext_inreg (rotl (bswap) 16))
if (EVTBits <= 16 && N0.getOpcode() == ISD::OR) {
- SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
- N0.getOperand(1), false);
- if (BSwap.getNode())
+ if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
+ N0.getOperand(1), false))
return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT,
BSwap, N1);
}
@@ -7002,12 +9636,30 @@ SDValue DAGCombiner::visitSIGN_EXTEND_VECTOR_INREG(SDNode *N) {
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);
- if (N0.getOpcode() == ISD::UNDEF)
+ if (N0.isUndef())
+ return DAG.getUNDEF(VT);
+
+ if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
+ return Res;
+
+ if (SimplifyDemandedVectorElts(SDValue(N, 0)))
+ return SDValue(N, 0);
+
+ return SDValue();
+}
+
+SDValue DAGCombiner::visitZERO_EXTEND_VECTOR_INREG(SDNode *N) {
+ SDValue N0 = N->getOperand(0);
+ EVT VT = N->getValueType(0);
+
+ if (N0.isUndef())
return DAG.getUNDEF(VT);
- if (SDNode *Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes,
- LegalOperations))
- return SDValue(Res, 0);
+ if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
+ return Res;
+
+ if (SimplifyDemandedVectorElts(SDValue(N, 0)))
+ return SDValue(N, 0);
return SDValue();
}
@@ -7020,28 +9672,37 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
// noop truncate
if (N0.getValueType() == N->getValueType(0))
return N0;
- // fold (truncate c1) -> c1
- if (isConstantIntBuildVectorOrConstantInt(N0))
- return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0);
+
// fold (truncate (truncate x)) -> (truncate x)
if (N0.getOpcode() == ISD::TRUNCATE)
return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0.getOperand(0));
+
+ // fold (truncate c1) -> c1
+ if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) {
+ SDValue C = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0);
+ if (C.getNode() != N)
+ return C;
+ }
+
// fold (truncate (ext x)) -> (ext x) or (truncate x) or x
if (N0.getOpcode() == ISD::ZERO_EXTEND ||
N0.getOpcode() == ISD::SIGN_EXTEND ||
N0.getOpcode() == ISD::ANY_EXTEND) {
+ // if the source is smaller than the dest, we still need an extend.
if (N0.getOperand(0).getValueType().bitsLT(VT))
- // if the source is smaller than the dest, we still need an extend
- return DAG.getNode(N0.getOpcode(), SDLoc(N), VT,
- N0.getOperand(0));
+ return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, N0.getOperand(0));
+    // if the source is larger than the dest, then we just need the truncate.
if (N0.getOperand(0).getValueType().bitsGT(VT))
- // if the source is larger than the dest, than we just need the truncate
return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0.getOperand(0));
// if the source and dest are the same type, we can drop both the extend
// and the truncate.
return N0.getOperand(0);
}
+ // If this is anyext(trunc), don't fold it, allow ourselves to be folded.
+ if (N->hasOneUse() && (N->use_begin()->getOpcode() == ISD::ANY_EXTEND))
+ return SDValue();
+
// Fold extract-and-trunc into a narrow extract. For example:
// i64 x = EXTRACT_VECTOR_ELT(v2i64 val, i32 1)
// i32 y = TRUNCATE(i64 x)
@@ -7054,7 +9715,6 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
// we need to be more careful about the vector instructions that we generate.
if (N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
LegalTypes && !LegalOperations && N0->hasOneUse() && VT != MVT::i1) {
-
EVT VecTy = N0.getOperand(0).getValueType();
EVT ExTy = N0.getValueType();
EVT TrTy = N->getValueType(0);
@@ -7071,18 +9731,15 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
EVT IndexTy = TLI.getVectorIdxTy(DAG.getDataLayout());
int Index = isLE ? (Elt*SizeRatio) : (Elt*SizeRatio + (SizeRatio-1));
- SDValue V = DAG.getNode(ISD::BITCAST, SDLoc(N),
- NVT, N0.getOperand(0));
-
SDLoc DL(N);
- return DAG.getNode(ISD::EXTRACT_VECTOR_ELT,
- DL, TrTy, V,
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, TrTy,
+ DAG.getBitcast(NVT, N0.getOperand(0)),
DAG.getConstant(Index, DL, IndexTy));
}
}
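(The index formula Elt * SizeRatio in the little-endian case comes straight from the byte layout: the low TrTy-sized lane of a wide element is the first of its SizeRatio lanes in the bitcast vector. A small demo, assuming a little-endian host so that memcpy reproduces that layout:)

    #include <cassert>
    #include <cstdint>
    #include <cstring>

    int main() {
      uint64_t v2i64[2] = {0x1111111122222222ull, 0x3333333344444444ull};
      uint32_t v4i32[4];
      std::memcpy(v4i32, v2i64, sizeof(v2i64));  // the bitcast
      unsigned elt = 1, sizeRatio = 2;           // 64 / 32
      // trunc (extract_vector_elt v2i64, 1) == extract_vector_elt v4i32, 1*2
      assert(v4i32[elt * sizeRatio] == static_cast<uint32_t>(v2i64[elt]));
      return 0;
    }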
// trunc (select c, a, b) -> select c, (trunc a), (trunc b)
- if (N0.getOpcode() == ISD::SELECT) {
+ if (N0.getOpcode() == ISD::SELECT && N0.hasOneUse()) {
EVT SrcVT = N0.getValueType();
if ((!LegalOperations || TLI.isOperationLegal(ISD::SELECT, SrcVT)) &&
TLI.isTruncateFree(SrcVT, VT)) {
@@ -7094,6 +9751,26 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
}
}
+ // trunc (shl x, K) -> shl (trunc x), K => K < VT.getScalarSizeInBits()
+ if (N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
+ (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::SHL, VT)) &&
+ TLI.isTypeDesirableForOp(ISD::SHL, VT)) {
+ SDValue Amt = N0.getOperand(1);
+ KnownBits Known = DAG.computeKnownBits(Amt);
+ unsigned Size = VT.getScalarSizeInBits();
+ if (Known.getBitWidth() - Known.countMinLeadingZeros() <= Log2_32(Size)) {
+ SDLoc SL(N);
+ EVT AmtVT = TLI.getShiftAmountTy(VT, DAG.getDataLayout());
+
+ SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(0));
+ if (AmtVT != Amt.getValueType()) {
+ Amt = DAG.getZExtOrTrunc(Amt, SL, AmtVT);
+ AddToWorklist(Amt.getNode());
+ }
+ return DAG.getNode(ISD::SHL, SL, VT, Trunc, Amt);
+ }
+ }
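(This is sound because a left shift by K smaller than the narrow width never moves discarded high bits back into the low bits, so truncation and the shift commute. A standalone check, illustrative:)

    #include <cassert>
    #include <cstdint>

    int main() {
      uint64_t x = 0x1234567890ABCDEFull;
      for (unsigned k = 0; k < 32; ++k) {  // K < narrow width (32)
        uint32_t wide = static_cast<uint32_t>(x << k);  // trunc (shl x, K)
        uint32_t slim = static_cast<uint32_t>(x) << k;  // shl (trunc x), K
        assert(wide == slim);
      }
      return 0;
    }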
+
// Fold a series of buildvector, bitcast, and truncate if possible.
// For example fold
// (2xi32 trunc (bitcast ((4xi32)buildvector x, x, y, y) 2xi64)) to
@@ -7102,7 +9779,6 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
N0.getOperand(0).getOpcode() == ISD::BUILD_VECTOR &&
N0.getOperand(0).hasOneUse()) {
-
SDValue BuildVect = N0.getOperand(0);
EVT BuildVectEltTy = BuildVect.getValueType().getVectorElementType();
EVT TruncVecEltTy = VT.getVectorElementType();
@@ -7121,7 +9797,7 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
for (unsigned i = 0, e = BuildVecNumElts; i != e; i += TruncEltOffset)
Opnds.push_back(BuildVect.getOperand(i));
- return DAG.getNode(ISD::BUILD_VECTOR, SDLoc(N), VT, Opnds);
+ return DAG.getBuildVector(VT, SDLoc(N), Opnds);
}
}
@@ -7131,12 +9807,12 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
// Currently we only perform this optimization on scalars because vectors
// may have different active low bits.
if (!VT.isVector()) {
- SDValue Shorter =
- GetDemandedBits(N0, APInt::getLowBitsSet(N0.getValueSizeInBits(),
- VT.getSizeInBits()));
- if (Shorter.getNode())
+ APInt Mask =
+ APInt::getLowBitsSet(N0.getValueSizeInBits(), VT.getSizeInBits());
+ if (SDValue Shorter = DAG.GetDemandedBits(N0, Mask))
return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Shorter);
}
+
// fold (truncate (load x)) -> (smaller load x)
// fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits))
if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) {
@@ -7158,6 +9834,7 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
}
}
}
+
// fold (trunc (concat ... x ...)) -> (concat ..., (trunc x), ...)),
// where ... are all 'undef'.
if (N0.getOpcode() == ISD::CONCAT_VECTORS && !LegalTypes) {
@@ -7168,7 +9845,7 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
for (unsigned i = 0, e = N0.getNumOperands(); i != e; ++i) {
SDValue X = N0.getOperand(i);
- if (X.getOpcode() != ISD::UNDEF) {
+ if (!X.isUndef()) {
V = X;
Idx = i;
NumDefs++;
@@ -7200,11 +9877,90 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
}
}
+ // Fold truncate of a bitcast of a vector to an extract of the low vector
+ // element.
+ //
+ // e.g. trunc (i64 (bitcast v2i32:x)) -> extract_vector_elt v2i32:x, idx
+ if (N0.getOpcode() == ISD::BITCAST && !VT.isVector()) {
+ SDValue VecSrc = N0.getOperand(0);
+ EVT SrcVT = VecSrc.getValueType();
+ if (SrcVT.isVector() && SrcVT.getScalarType() == VT &&
+ (!LegalOperations ||
+ TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, SrcVT))) {
+ SDLoc SL(N);
+
+ EVT IdxVT = TLI.getVectorIdxTy(DAG.getDataLayout());
+ unsigned Idx = isLE ? 0 : SrcVT.getVectorNumElements() - 1;
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, VT,
+ VecSrc, DAG.getConstant(Idx, SL, IdxVT));
+ }
+ }
+
// Simplify the operands using demanded-bits information.
if (!VT.isVector() &&
SimplifyDemandedBits(SDValue(N, 0)))
return SDValue(N, 0);
+ // (trunc adde(X, Y, Carry)) -> (adde trunc(X), trunc(Y), Carry)
+ // (trunc addcarry(X, Y, Carry)) -> (addcarry trunc(X), trunc(Y), Carry)
+ // When the adde's carry is not used.
+ // Don't make an illegal adde: LegalizeDAG can't expand nor promote it.
+ if ((N0.getOpcode() == ISD::ADDE || N0.getOpcode() == ISD::ADDCARRY) &&
+ N0.hasOneUse() && !N0.getNode()->hasAnyUseOfValue(1) &&
+ ((!LegalOperations && N0.getOpcode() == ISD::ADDCARRY) ||
+ TLI.isOperationLegal(N0.getOpcode(), VT))) {
+ SDLoc SL(N);
+ auto X = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(0));
+ auto Y = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(1));
+ auto VTs = DAG.getVTList(VT, N0->getValueType(1));
+ return DAG.getNode(N0.getOpcode(), SL, VTs, X, Y, N0.getOperand(2));
+ }
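(This transform is the point of the legality guard in the comment above: the low half of an add-with-carry depends only on the low halves of the operands and the carry-in, so when the carry-out is unused the truncate can move inside -- provided the narrower adde/addcarry is something the target can actually handle. A plain-integer check of the identity, not part of the patch:)

    #include <cassert>
    #include <cstdint>

    int main() {
      uint64_t x = 0xDEADBEEFCAFEF00Dull, y = 0x0123456789ABCDEFull;
      for (unsigned carry = 0; carry <= 1; ++carry) {
        // trunc (adde x, y, carry) -- low half of the wide sum:
        uint32_t wide = static_cast<uint32_t>(x + y + carry);
        // adde (trunc x), (trunc y), carry:
        uint32_t slim = static_cast<uint32_t>(x) + static_cast<uint32_t>(y) + carry;
        assert(wide == slim);
      }
      return 0;
    }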
+
+ // fold (truncate (extract_subvector(ext x))) ->
+ // (extract_subvector x)
+ // TODO: This can be generalized to cover cases where the truncate and extract
+ // do not fully cancel each other out.
+ if (!LegalTypes && N0.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
+ SDValue N00 = N0.getOperand(0);
+ if (N00.getOpcode() == ISD::SIGN_EXTEND ||
+ N00.getOpcode() == ISD::ZERO_EXTEND ||
+ N00.getOpcode() == ISD::ANY_EXTEND) {
+ if (N00.getOperand(0)->getValueType(0).getVectorElementType() ==
+ VT.getVectorElementType())
+ return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N0->getOperand(0)), VT,
+ N00.getOperand(0), N0.getOperand(1));
+ }
+ }
+
+ if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
+ return NewVSel;
+
+ // Narrow a suitable binary operation with a non-opaque constant operand by
+ // moving it ahead of the truncate. This is limited to pre-legalization
+ // because targets may prefer a wider type during later combines and invert
+ // this transform.
+ switch (N0.getOpcode()) {
+ case ISD::ADD:
+ case ISD::SUB:
+ case ISD::MUL:
+ case ISD::AND:
+ case ISD::OR:
+ case ISD::XOR:
+ if (!LegalOperations && N0.hasOneUse() &&
+ (isConstantOrConstantVector(N0.getOperand(0), true) ||
+ isConstantOrConstantVector(N0.getOperand(1), true))) {
+ // TODO: We already restricted this to pre-legalization, but for vectors
+ // we are extra cautious to not create an unsupported operation.
+ // Target-specific changes are likely needed to avoid regressions here.
+ if (VT.isScalarInteger() || TLI.isOperationLegal(N0.getOpcode(), VT)) {
+ SDLoc DL(N);
+ SDValue NarrowL = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
+ SDValue NarrowR = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(1));
+ return DAG.getNode(N0.getOpcode(), DL, VT, NarrowL, NarrowR);
+ }
+ }
+ }
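(All six opcodes in this switch share the property that the low bits of the result are a function of the low bits of the inputs alone, which is what makes the narrowing safe. A spot check, illustration only:)

    #include <cassert>
    #include <cstdint>

    int main() {
      uint64_t x = 0xFEDCBA9876543210ull, c = 0x00000000DEADBEEFull;
      assert(static_cast<uint32_t>(x + c) ==
             static_cast<uint32_t>(x) + static_cast<uint32_t>(c));
      assert(static_cast<uint32_t>(x * c) ==
             static_cast<uint32_t>(x) * static_cast<uint32_t>(c));
      assert(static_cast<uint32_t>(x ^ c) ==
             (static_cast<uint32_t>(x) ^ static_cast<uint32_t>(c)));
      return 0;
    }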
+
return SDValue();
}
@@ -7222,27 +9978,28 @@ SDValue DAGCombiner::CombineConsecutiveLoads(SDNode *N, EVT VT) {
LoadSDNode *LD1 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 0));
LoadSDNode *LD2 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 1));
+
+  // A BUILD_PAIR always has the least significant part in elt 0 and the
+ // most significant part in elt 1. So when combining into one large load, we
+ // need to consider the endianness.
+ if (DAG.getDataLayout().isBigEndian())
+ std::swap(LD1, LD2);
+
if (!LD1 || !LD2 || !ISD::isNON_EXTLoad(LD1) || !LD1->hasOneUse() ||
LD1->getAddressSpace() != LD2->getAddressSpace())
return SDValue();
EVT LD1VT = LD1->getValueType(0);
-
- if (ISD::isNON_EXTLoad(LD2) &&
- LD2->hasOneUse() &&
- // If both are volatile this would reduce the number of volatile loads.
- // If one is volatile it might be ok, but play conservative and bail out.
- !LD1->isVolatile() &&
- !LD2->isVolatile() &&
- DAG.isConsecutiveLoad(LD2, LD1, LD1VT.getSizeInBits()/8, 1)) {
+ unsigned LD1Bytes = LD1VT.getStoreSize();
+ if (ISD::isNON_EXTLoad(LD2) && LD2->hasOneUse() &&
+ DAG.areNonVolatileConsecutiveLoads(LD2, LD1, LD1Bytes, 1)) {
unsigned Align = LD1->getAlignment();
unsigned NewAlign = DAG.getDataLayout().getABITypeAlignment(
VT.getTypeForEVT(*DAG.getContext()));
if (NewAlign <= Align &&
(!LegalOperations || TLI.isOperationLegal(ISD::LOAD, VT)))
- return DAG.getLoad(VT, SDLoc(N), LD1->getChain(),
- LD1->getBasePtr(), LD1->getPointerInfo(),
- false, false, false, Align);
+ return DAG.getLoad(VT, SDLoc(N), LD1->getChain(), LD1->getBasePtr(),
+ LD1->getPointerInfo(), Align);
}
return SDValue();
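(The swap at the top of this function reflects memory order: on big-endian the load at the lower address holds the most significant half of the pair. A host-independent sketch, with the byte image and loads simulated by hand:)

    #include <cassert>
    #include <cstdint>

    static uint32_t load32(const uint8_t *p, bool bigEndian) {
      uint32_t v = 0;
      for (int i = 0; i < 4; ++i)
        v |= static_cast<uint32_t>(p[i]) << (bigEndian ? 24 - 8 * i : 8 * i);
      return v;
    }

    int main() {
      // Byte image of 0x1122334455667788 in big-endian order.
      const uint8_t mem[8] = {0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88};
      uint32_t ld1 = load32(mem + 0, true);  // load at the lower address
      uint32_t ld2 = load32(mem + 4, true);  // the consecutive load
      // On big-endian, ld1 is the MOST significant half of the pair:
      uint64_t wide = (static_cast<uint64_t>(ld1) << 32) | ld2;
      assert(wide == 0x1122334455667788ull);
      return 0;
    }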
@@ -7254,25 +10011,77 @@ static unsigned getPPCf128HiElementSelector(const SelectionDAG &DAG) {
return DAG.getDataLayout().isBigEndian() ? 1 : 0;
}
+static SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
+ const TargetLowering &TLI) {
+ // If this is not a bitcast to an FP type or if the target doesn't have
+ // IEEE754-compliant FP logic, we're done.
+ EVT VT = N->getValueType(0);
+ if (!VT.isFloatingPoint() || !TLI.hasBitPreservingFPLogic(VT))
+ return SDValue();
+
+ // TODO: Handle cases where the integer constant is a different scalar
+ // bitwidth to the FP.
+ SDValue N0 = N->getOperand(0);
+ EVT SourceVT = N0.getValueType();
+ if (VT.getScalarSizeInBits() != SourceVT.getScalarSizeInBits())
+ return SDValue();
+
+ unsigned FPOpcode;
+ APInt SignMask;
+ switch (N0.getOpcode()) {
+ case ISD::AND:
+ FPOpcode = ISD::FABS;
+ SignMask = ~APInt::getSignMask(SourceVT.getScalarSizeInBits());
+ break;
+ case ISD::XOR:
+ FPOpcode = ISD::FNEG;
+ SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
+ break;
+ case ISD::OR:
+ FPOpcode = ISD::FABS;
+ SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
+ break;
+ default:
+ return SDValue();
+ }
+
+ // Fold (bitcast int (and (bitcast fp X to int), 0x7fff...) to fp) -> fabs X
+ // Fold (bitcast int (xor (bitcast fp X to int), 0x8000...) to fp) -> fneg X
+ // Fold (bitcast int (or (bitcast fp X to int), 0x8000...) to fp) ->
+ // fneg (fabs X)
+ SDValue LogicOp0 = N0.getOperand(0);
+ ConstantSDNode *LogicOp1 = isConstOrConstSplat(N0.getOperand(1), true);
+ if (LogicOp1 && LogicOp1->getAPIntValue() == SignMask &&
+ LogicOp0.getOpcode() == ISD::BITCAST &&
+ LogicOp0.getOperand(0).getValueType() == VT) {
+ SDValue FPOp = DAG.getNode(FPOpcode, SDLoc(N), VT, LogicOp0.getOperand(0));
+ NumFPLogicOpsConv++;
+ if (N0.getOpcode() == ISD::OR)
+ return DAG.getNode(ISD::FNEG, SDLoc(N), VT, FPOp);
+ return FPOp;
+ }
+
+ return SDValue();
+}
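(The three cases recognized here are the classic sign-bit tricks on IEEE754 values: AND with the inverted sign mask clears the sign (fabs), XOR flips it (fneg), OR sets it (fneg of fabs). A memcpy-based demo on float; assumes 32-bit IEEE754 float, with helper names local to the demo:)

    #include <cassert>
    #include <cmath>
    #include <cstdint>
    #include <cstring>

    static uint32_t asBits(float f) { uint32_t u; std::memcpy(&u, &f, 4); return u; }
    static float asFloat(uint32_t u) { float f; std::memcpy(&f, &u, 4); return f; }

    int main() {
      const uint32_t signMask = 0x80000000u;
      float x = -3.5f;
      assert(asFloat(asBits(x) & ~signMask) == std::fabs(x)); // AND -> fabs
      assert(asFloat(asBits(x) ^ signMask) == -x);            // XOR -> fneg
      assert(asFloat(asBits(x) | signMask) == -std::fabs(x)); // OR  -> fneg (fabs)
      return 0;
    }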
+
SDValue DAGCombiner::visitBITCAST(SDNode *N) {
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);
+ if (N0.isUndef())
+ return DAG.getUNDEF(VT);
+
// If the input is a BUILD_VECTOR with all constant elements, fold this now.
- // Only do this before legalize, since afterward the target may be depending
- // on the bitconvert.
+ // Only do this before legalize types, since we might create an illegal
+ // scalar type. Even if we knew we wouldn't create an illegal scalar type
+  // we can only do this before legalize ops, since the target may be
+  // depending on the bitcast.
// First check to see if this is all constant.
if (!LegalTypes &&
N0.getOpcode() == ISD::BUILD_VECTOR && N0.getNode()->hasOneUse() &&
- VT.isVector()) {
- bool isSimple = cast<BuildVectorSDNode>(N0)->isConstant();
-
- EVT DestEltVT = N->getValueType(0).getVectorElementType();
- assert(!DestEltVT.isVector() &&
- "Element type of vector ValueType must not be vector!");
- if (isSimple)
- return ConstantFoldBITCASTofBUILD_VECTOR(N0.getNode(), DestEltVT);
- }
+ VT.isVector() && cast<BuildVectorSDNode>(N0)->isConstant())
+ return ConstantFoldBITCASTofBUILD_VECTOR(N0.getNode(),
+ VT.getVectorElementType());
// If the input is a constant, let getNode fold it.
if (isa<ConstantSDNode>(N0) || isa<ConstantFPSDNode>(N0)) {
@@ -7283,41 +10092,50 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) {
(isa<ConstantSDNode>(N0) && VT.isFloatingPoint() && !VT.isVector() &&
TLI.isOperationLegal(ISD::ConstantFP, VT)) ||
(isa<ConstantFPSDNode>(N0) && VT.isInteger() && !VT.isVector() &&
- TLI.isOperationLegal(ISD::Constant, VT)))
- return DAG.getNode(ISD::BITCAST, SDLoc(N), VT, N0);
+ TLI.isOperationLegal(ISD::Constant, VT))) {
+ SDValue C = DAG.getBitcast(VT, N0);
+ if (C.getNode() != N)
+ return C;
+ }
}
// (conv (conv x, t1), t2) -> (conv x, t2)
if (N0.getOpcode() == ISD::BITCAST)
- return DAG.getNode(ISD::BITCAST, SDLoc(N), VT,
- N0.getOperand(0));
+ return DAG.getBitcast(VT, N0.getOperand(0));
// fold (conv (load x)) -> (load (conv*)x)
// If the resultant load doesn't need a higher alignment than the original!
if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
- // Do not change the width of a volatile load.
- !cast<LoadSDNode>(N0)->isVolatile() &&
// Do not remove the cast if the types differ in endian layout.
TLI.hasBigEndianPartOrdering(N0.getValueType(), DAG.getDataLayout()) ==
TLI.hasBigEndianPartOrdering(VT, DAG.getDataLayout()) &&
- (!LegalOperations || TLI.isOperationLegal(ISD::LOAD, VT)) &&
+ // If the load is volatile, we only want to change the load type if the
+ // resulting load is legal. Otherwise we might increase the number of
+ // memory accesses. We don't care if the original type was legal or not
+ // as we assume software couldn't rely on the number of accesses of an
+ // illegal type.
+ ((!LegalOperations && !cast<LoadSDNode>(N0)->isVolatile()) ||
+ TLI.isOperationLegal(ISD::LOAD, VT)) &&
TLI.isLoadBitCastBeneficial(N0.getValueType(), VT)) {
LoadSDNode *LN0 = cast<LoadSDNode>(N0);
- unsigned Align = DAG.getDataLayout().getABITypeAlignment(
- VT.getTypeForEVT(*DAG.getContext()));
unsigned OrigAlign = LN0->getAlignment();
- if (Align <= OrigAlign) {
- SDValue Load = DAG.getLoad(VT, SDLoc(N), LN0->getChain(),
- LN0->getBasePtr(), LN0->getPointerInfo(),
- LN0->isVolatile(), LN0->isNonTemporal(),
- LN0->isInvariant(), OrigAlign,
- LN0->getAAInfo());
+ bool Fast = false;
+ if (TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
+ LN0->getAddressSpace(), OrigAlign, &Fast) &&
+ Fast) {
+ SDValue Load =
+ DAG.getLoad(VT, SDLoc(N), LN0->getChain(), LN0->getBasePtr(),
+ LN0->getPointerInfo(), OrigAlign,
+ LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
return Load;
}
}
+ if (SDValue V = foldBitcastedFPLogic(N, DAG, TLI))
+ return V;
+
// fold (bitconvert (fneg x)) -> (xor (bitconvert x), signbit)
// fold (bitconvert (fabs x)) -> (and (bitconvert x), (not signbit))
//
@@ -7334,15 +10152,14 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) {
(N0.getOpcode() == ISD::FABS && !TLI.isFAbsFree(N0.getValueType()))) &&
N0.getNode()->hasOneUse() && VT.isInteger() &&
!VT.isVector() && !N0.getValueType().isVector()) {
- SDValue NewConv = DAG.getNode(ISD::BITCAST, SDLoc(N0), VT,
- N0.getOperand(0));
+ SDValue NewConv = DAG.getBitcast(VT, N0.getOperand(0));
AddToWorklist(NewConv.getNode());
SDLoc DL(N);
if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
assert(VT.getSizeInBits() == 128);
SDValue SignBit = DAG.getConstant(
- APInt::getSignBit(VT.getSizeInBits() / 2), SDLoc(N0), MVT::i64);
+ APInt::getSignMask(VT.getSizeInBits() / 2), SDLoc(N0), MVT::i64);
SDValue FlipBit;
if (N0.getOpcode() == ISD::FNEG) {
FlipBit = SignBit;
@@ -7362,7 +10179,7 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) {
AddToWorklist(FlipBits.getNode());
return DAG.getNode(ISD::XOR, DL, VT, NewConv, FlipBits);
}
- APInt SignBit = APInt::getSignBit(VT.getSizeInBits());
+ APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
if (N0.getOpcode() == ISD::FNEG)
return DAG.getNode(ISD::XOR, DL, VT,
NewConv, DAG.getConstant(SignBit, DL, VT));
@@ -7385,11 +10202,10 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) {
if (N0.getOpcode() == ISD::FCOPYSIGN && N0.getNode()->hasOneUse() &&
isa<ConstantFPSDNode>(N0.getOperand(0)) &&
VT.isInteger() && !VT.isVector()) {
- unsigned OrigXWidth = N0.getOperand(1).getValueType().getSizeInBits();
+ unsigned OrigXWidth = N0.getOperand(1).getValueSizeInBits();
EVT IntXVT = EVT::getIntegerVT(*DAG.getContext(), OrigXWidth);
if (isTypeLegal(IntXVT)) {
- SDValue X = DAG.getNode(ISD::BITCAST, SDLoc(N0),
- IntXVT, N0.getOperand(1));
+ SDValue X = DAG.getBitcast(IntXVT, N0.getOperand(1));
AddToWorklist(X.getNode());
// If X has a different width than the result/lhs, sext it or truncate it.
@@ -7411,12 +10227,10 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) {
}
if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
- APInt SignBit = APInt::getSignBit(VT.getSizeInBits() / 2);
- SDValue Cst = DAG.getNode(ISD::BITCAST, SDLoc(N0.getOperand(0)), VT,
- N0.getOperand(0));
+ APInt SignBit = APInt::getSignMask(VT.getSizeInBits() / 2);
+ SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
AddToWorklist(Cst.getNode());
- SDValue X = DAG.getNode(ISD::BITCAST, SDLoc(N0.getOperand(1)), VT,
- N0.getOperand(1));
+ SDValue X = DAG.getBitcast(VT, N0.getOperand(1));
AddToWorklist(X.getNode());
SDValue XorResult = DAG.getNode(ISD::XOR, SDLoc(N0), VT, Cst, X);
AddToWorklist(XorResult.getNode());
@@ -7434,13 +10248,12 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) {
AddToWorklist(FlipBits.getNode());
return DAG.getNode(ISD::XOR, SDLoc(N), VT, Cst, FlipBits);
}
- APInt SignBit = APInt::getSignBit(VT.getSizeInBits());
+ APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
X = DAG.getNode(ISD::AND, SDLoc(X), VT,
X, DAG.getConstant(SignBit, SDLoc(X), VT));
AddToWorklist(X.getNode());
- SDValue Cst = DAG.getNode(ISD::BITCAST, SDLoc(N0),
- VT, N0.getOperand(0));
+ SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
Cst = DAG.getNode(ISD::AND, SDLoc(Cst), VT,
Cst, DAG.getConstant(~SignBit, SDLoc(Cst), VT));
AddToWorklist(Cst.getNode());
@@ -7459,7 +10272,7 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) {
// float vectors bitcast to integer vectors) into shuffles.
// bitcast(shuffle(bitcast(s0),bitcast(s1))) -> shuffle(s0,s1)
if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT) && VT.isVector() &&
- N0->getOpcode() == ISD::VECTOR_SHUFFLE &&
+ N0->getOpcode() == ISD::VECTOR_SHUFFLE && N0.hasOneUse() &&
VT.getVectorNumElements() >= N0.getValueType().getVectorNumElements() &&
!(VT.getVectorNumElements() % N0.getValueType().getVectorNumElements())) {
ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N0);
@@ -7470,12 +10283,15 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) {
if (Op.getOpcode() == ISD::BITCAST &&
Op.getOperand(0).getValueType() == VT)
return SDValue(Op.getOperand(0));
- if (ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) ||
+ if (Op.isUndef() || ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) ||
ISD::isBuildVectorOfConstantFPSDNodes(Op.getNode()))
- return DAG.getNode(ISD::BITCAST, SDLoc(N), VT, Op);
+ return DAG.getBitcast(VT, Op);
return SDValue();
};
+ // FIXME: If either input vector is bitcast, try to convert the shuffle to
+ // the result type of this bitcast. This would eliminate at least one
+ // bitcast. See the transform in InstCombine.
SDValue SV0 = PeekThroughBitcast(N0->getOperand(0));
SDValue SV1 = PeekThroughBitcast(N0->getOperand(1));
if (!(SV0 && SV1))
@@ -7522,27 +10338,18 @@ ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
// If this is a conversion of N elements of one type to N elements of another
// type, convert each element. This handles FP<->INT cases.
if (SrcBitSize == DstBitSize) {
- EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT,
- BV->getValueType(0).getVectorNumElements());
-
- // Due to the FP element handling below calling this routine recursively,
- // we can end up with a scalar-to-vector node here.
- if (BV->getOpcode() == ISD::SCALAR_TO_VECTOR)
- return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(BV), VT,
- DAG.getNode(ISD::BITCAST, SDLoc(BV),
- DstEltVT, BV->getOperand(0)));
-
SmallVector<SDValue, 8> Ops;
for (SDValue Op : BV->op_values()) {
// If the vector element type is not legal, the BUILD_VECTOR operands
// are promoted and implicitly truncated. Make that explicit here.
if (Op.getValueType() != SrcEltVT)
Op = DAG.getNode(ISD::TRUNCATE, SDLoc(BV), SrcEltVT, Op);
- Ops.push_back(DAG.getNode(ISD::BITCAST, SDLoc(BV),
- DstEltVT, Op));
+ Ops.push_back(DAG.getBitcast(DstEltVT, Op));
AddToWorklist(Ops.back().getNode());
}
- return DAG.getNode(ISD::BUILD_VECTOR, SDLoc(BV), VT, Ops);
+ EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT,
+ BV->getValueType(0).getVectorNumElements());
+ return DAG.getBuildVector(VT, SDLoc(BV), Ops);
}
// Otherwise, we're growing or shrinking the elements. To avoid having to
@@ -7584,7 +10391,7 @@ ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
// Shift the previously computed bits over.
NewBits <<= SrcBitSize;
SDValue Op = BV->getOperand(i+ (isLE ? (NumInputsPerOutput-j-1) : j));
- if (Op.getOpcode() == ISD::UNDEF) continue;
+ if (Op.isUndef()) continue;
EltIsUndef = false;
NewBits |= cast<ConstantSDNode>(Op)->getAPIntValue().
@@ -7598,7 +10405,7 @@ ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
}
EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT, Ops.size());
- return DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Ops);
+ return DAG.getBuildVector(VT, DL, Ops);
}
// Finally, this must be the case where we are shrinking elements: each input
@@ -7609,7 +10416,7 @@ ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
SmallVector<SDValue, 8> Ops;
for (const SDValue &Op : BV->op_values()) {
- if (Op.getOpcode() == ISD::UNDEF) {
+ if (Op.isUndef()) {
Ops.append(NumOutputsPerInput, DAG.getUNDEF(DstEltVT));
continue;
}
@@ -7620,7 +10427,7 @@ ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
for (unsigned j = 0; j != NumOutputsPerInput; ++j) {
APInt ThisVal = OpVal.trunc(DstBitSize);
Ops.push_back(DAG.getConstant(ThisVal, DL, DstEltVT));
- OpVal = OpVal.lshr(DstBitSize);
+ OpVal.lshrInPlace(DstBitSize);
}
// For big endian targets, swap the order of the pieces of each element.
@@ -7628,7 +10435,12 @@ ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
std::reverse(Ops.end()-NumOutputsPerInput, Ops.end());
}
- return DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Ops);
+ return DAG.getBuildVector(VT, DL, Ops);
+}
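(The shrinking loop above peels DstBitSize pieces off each wide constant, least significant first, then reverses per element on big-endian. The same arithmetic in miniature, illustrative:)

    #include <cassert>
    #include <cstdint>

    int main() {
      uint64_t opVal = 0x1122334455667788ull;     // one wide BUILD_VECTOR element
      uint32_t pieces[2];
      for (int j = 0; j < 2; ++j) {
        pieces[j] = static_cast<uint32_t>(opVal); // OpVal.trunc(DstBitSize)
        opVal >>= 32;                             // OpVal.lshrInPlace(DstBitSize)
      }
      assert(pieces[0] == 0x55667788u && pieces[1] == 0x11223344u);
      // A big-endian target would now reverse these two pieces.
      return 0;
    }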
+
+static bool isContractable(SDNode *N) {
+ SDNodeFlags F = N->getFlags();
+ return F.hasAllowContract() || F.hasAllowReassociation();
}
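(Contractability matters because fusing changes results: fma rounds once where fmul followed by fadd rounds twice. A small demonstration of a case where the two differ in the last place; uses std::fma, and the volatile blocks the compiler from re-contracting the separate form:)

    #include <cassert>
    #include <cmath>

    int main() {
      double x = 1.0 + 0x1p-27, y = 1.0 + 0x1p-27, z = -1.0;
      double fused = std::fma(x, y, z);  // one rounding
      volatile double xy = x * y;        // force the intermediate rounding
      double separate = xy + z;          // two roundings
      assert(fused != separate);         // 2^-26 + 2^-54 vs. 2^-26
      return 0;
    }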
/// Try to perform FMA combining on a given FADD node.
@@ -7639,173 +10451,202 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
SDLoc SL(N);
const TargetOptions &Options = DAG.getTarget().Options;
- bool AllowFusion =
- (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath);
// Floating-point multiply-add with intermediate rounding.
bool HasFMAD = (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT));
// Floating-point multiply-add without intermediate rounding.
bool HasFMA =
- AllowFusion && TLI.isFMAFasterThanFMulAndFAdd(VT) &&
+ TLI.isFMAFasterThanFMulAndFAdd(VT) &&
(!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
// No valid opcode, do not combine.
if (!HasFMAD && !HasFMA)
return SDValue();
+ SDNodeFlags Flags = N->getFlags();
+ bool CanFuse = Options.UnsafeFPMath || isContractable(N);
+ bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
+ CanFuse || HasFMAD);
+ // If the addition is not contractable, do not combine.
+ if (!AllowFusionGlobally && !isContractable(N))
+ return SDValue();
+
+ const SelectionDAGTargetInfo *STI = DAG.getSubtarget().getSelectionDAGInfo();
+ if (STI && STI->generateFMAsInMachineCombiner(OptLevel))
+ return SDValue();
+
// Always prefer FMAD to FMA for precision.
unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
- bool LookThroughFPExt = TLI.isFPExtFree(VT);
+ // Is the node an FMUL and contractable either due to global flags or
+ // SDNodeFlags.
+ auto isContractableFMUL = [AllowFusionGlobally](SDValue N) {
+ if (N.getOpcode() != ISD::FMUL)
+ return false;
+ return AllowFusionGlobally || isContractable(N.getNode());
+ };
// If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
// prefer to fold the multiply with fewer uses.
- if (Aggressive && N0.getOpcode() == ISD::FMUL &&
- N1.getOpcode() == ISD::FMUL) {
+ if (Aggressive && isContractableFMUL(N0) && isContractableFMUL(N1)) {
if (N0.getNode()->use_size() > N1.getNode()->use_size())
std::swap(N0, N1);
}
// fold (fadd (fmul x, y), z) -> (fma x, y, z)
- if (N0.getOpcode() == ISD::FMUL &&
- (Aggressive || N0->hasOneUse())) {
+ if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) {
return DAG.getNode(PreferredFusedOpcode, SL, VT,
- N0.getOperand(0), N0.getOperand(1), N1);
+ N0.getOperand(0), N0.getOperand(1), N1, Flags);
}
// fold (fadd x, (fmul y, z)) -> (fma y, z, x)
// Note: Commutes FADD operands.
- if (N1.getOpcode() == ISD::FMUL &&
- (Aggressive || N1->hasOneUse())) {
+ if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) {
return DAG.getNode(PreferredFusedOpcode, SL, VT,
- N1.getOperand(0), N1.getOperand(1), N0);
+ N1.getOperand(0), N1.getOperand(1), N0, Flags);
}
// Look through FP_EXTEND nodes to do more combining.
- if (AllowFusion && LookThroughFPExt) {
- // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
- if (N0.getOpcode() == ISD::FP_EXTEND) {
- SDValue N00 = N0.getOperand(0);
- if (N00.getOpcode() == ISD::FMUL)
- return DAG.getNode(PreferredFusedOpcode, SL, VT,
- DAG.getNode(ISD::FP_EXTEND, SL, VT,
- N00.getOperand(0)),
- DAG.getNode(ISD::FP_EXTEND, SL, VT,
- N00.getOperand(1)), N1);
+
+ // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
+ if (N0.getOpcode() == ISD::FP_EXTEND) {
+ SDValue N00 = N0.getOperand(0);
+ if (isContractableFMUL(N00) &&
+ TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N00.getValueType())) {
+ return DAG.getNode(PreferredFusedOpcode, SL, VT,
+ DAG.getNode(ISD::FP_EXTEND, SL, VT,
+ N00.getOperand(0)),
+ DAG.getNode(ISD::FP_EXTEND, SL, VT,
+ N00.getOperand(1)), N1, Flags);
}
+ }
- // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x)
- // Note: Commutes FADD operands.
- if (N1.getOpcode() == ISD::FP_EXTEND) {
- SDValue N10 = N1.getOperand(0);
- if (N10.getOpcode() == ISD::FMUL)
- return DAG.getNode(PreferredFusedOpcode, SL, VT,
- DAG.getNode(ISD::FP_EXTEND, SL, VT,
- N10.getOperand(0)),
- DAG.getNode(ISD::FP_EXTEND, SL, VT,
- N10.getOperand(1)), N0);
+ // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x)
+ // Note: Commutes FADD operands.
+ if (N1.getOpcode() == ISD::FP_EXTEND) {
+ SDValue N10 = N1.getOperand(0);
+ if (isContractableFMUL(N10) &&
+ TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N10.getValueType())) {
+ return DAG.getNode(PreferredFusedOpcode, SL, VT,
+ DAG.getNode(ISD::FP_EXTEND, SL, VT,
+ N10.getOperand(0)),
+ DAG.getNode(ISD::FP_EXTEND, SL, VT,
+ N10.getOperand(1)), N0, Flags);
}
}
// More folding opportunities when target permits.
- if ((AllowFusion || HasFMAD) && Aggressive) {
+ if (Aggressive) {
// fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y (fma u, v, z))
- if (N0.getOpcode() == PreferredFusedOpcode &&
- N0.getOperand(2).getOpcode() == ISD::FMUL) {
+ if (CanFuse &&
+ N0.getOpcode() == PreferredFusedOpcode &&
+ N0.getOperand(2).getOpcode() == ISD::FMUL &&
+ N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) {
return DAG.getNode(PreferredFusedOpcode, SL, VT,
N0.getOperand(0), N0.getOperand(1),
DAG.getNode(PreferredFusedOpcode, SL, VT,
N0.getOperand(2).getOperand(0),
N0.getOperand(2).getOperand(1),
- N1));
+ N1, Flags), Flags);
}
// fold (fadd x, (fma y, z, (fmul u, v)) -> (fma y, z (fma u, v, x))
- if (N1->getOpcode() == PreferredFusedOpcode &&
- N1.getOperand(2).getOpcode() == ISD::FMUL) {
+ if (CanFuse &&
+ N1->getOpcode() == PreferredFusedOpcode &&
+ N1.getOperand(2).getOpcode() == ISD::FMUL &&
+ N1->hasOneUse() && N1.getOperand(2)->hasOneUse()) {
return DAG.getNode(PreferredFusedOpcode, SL, VT,
N1.getOperand(0), N1.getOperand(1),
DAG.getNode(PreferredFusedOpcode, SL, VT,
N1.getOperand(2).getOperand(0),
N1.getOperand(2).getOperand(1),
- N0));
+ N0, Flags), Flags);
}
- if (AllowFusion && LookThroughFPExt) {
- // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
- // -> (fma x, y, (fma (fpext u), (fpext v), z))
- auto FoldFAddFMAFPExtFMul = [&] (
- SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z) {
- return DAG.getNode(PreferredFusedOpcode, SL, VT, X, Y,
- DAG.getNode(PreferredFusedOpcode, SL, VT,
- DAG.getNode(ISD::FP_EXTEND, SL, VT, U),
- DAG.getNode(ISD::FP_EXTEND, SL, VT, V),
- Z));
- };
- if (N0.getOpcode() == PreferredFusedOpcode) {
- SDValue N02 = N0.getOperand(2);
- if (N02.getOpcode() == ISD::FP_EXTEND) {
- SDValue N020 = N02.getOperand(0);
- if (N020.getOpcode() == ISD::FMUL)
- return FoldFAddFMAFPExtFMul(N0.getOperand(0), N0.getOperand(1),
- N020.getOperand(0), N020.getOperand(1),
- N1);
+
+ // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
+ // -> (fma x, y, (fma (fpext u), (fpext v), z))
+ auto FoldFAddFMAFPExtFMul = [&] (
+ SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z,
+ SDNodeFlags Flags) {
+ return DAG.getNode(PreferredFusedOpcode, SL, VT, X, Y,
+ DAG.getNode(PreferredFusedOpcode, SL, VT,
+ DAG.getNode(ISD::FP_EXTEND, SL, VT, U),
+ DAG.getNode(ISD::FP_EXTEND, SL, VT, V),
+ Z, Flags), Flags);
+ };
+ if (N0.getOpcode() == PreferredFusedOpcode) {
+ SDValue N02 = N0.getOperand(2);
+ if (N02.getOpcode() == ISD::FP_EXTEND) {
+ SDValue N020 = N02.getOperand(0);
+ if (isContractableFMUL(N020) &&
+ TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N020.getValueType())) {
+ return FoldFAddFMAFPExtFMul(N0.getOperand(0), N0.getOperand(1),
+ N020.getOperand(0), N020.getOperand(1),
+ N1, Flags);
}
}
+ }
- // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
- // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
- // FIXME: This turns two single-precision and one double-precision
- // operation into two double-precision operations, which might not be
- // interesting for all targets, especially GPUs.
- auto FoldFAddFPExtFMAFMul = [&] (
- SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z) {
- return DAG.getNode(PreferredFusedOpcode, SL, VT,
- DAG.getNode(ISD::FP_EXTEND, SL, VT, X),
- DAG.getNode(ISD::FP_EXTEND, SL, VT, Y),
- DAG.getNode(PreferredFusedOpcode, SL, VT,
- DAG.getNode(ISD::FP_EXTEND, SL, VT, U),
- DAG.getNode(ISD::FP_EXTEND, SL, VT, V),
- Z));
- };
- if (N0.getOpcode() == ISD::FP_EXTEND) {
- SDValue N00 = N0.getOperand(0);
- if (N00.getOpcode() == PreferredFusedOpcode) {
- SDValue N002 = N00.getOperand(2);
- if (N002.getOpcode() == ISD::FMUL)
- return FoldFAddFPExtFMAFMul(N00.getOperand(0), N00.getOperand(1),
- N002.getOperand(0), N002.getOperand(1),
- N1);
+ // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
+ // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
+ // FIXME: This turns two single-precision and one double-precision
+ // operation into two double-precision operations, which might not be
+ // interesting for all targets, especially GPUs.
+ auto FoldFAddFPExtFMAFMul = [&] (
+ SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z,
+ SDNodeFlags Flags) {
+ return DAG.getNode(PreferredFusedOpcode, SL, VT,
+ DAG.getNode(ISD::FP_EXTEND, SL, VT, X),
+ DAG.getNode(ISD::FP_EXTEND, SL, VT, Y),
+ DAG.getNode(PreferredFusedOpcode, SL, VT,
+ DAG.getNode(ISD::FP_EXTEND, SL, VT, U),
+ DAG.getNode(ISD::FP_EXTEND, SL, VT, V),
+ Z, Flags), Flags);
+ };
+ if (N0.getOpcode() == ISD::FP_EXTEND) {
+ SDValue N00 = N0.getOperand(0);
+ if (N00.getOpcode() == PreferredFusedOpcode) {
+ SDValue N002 = N00.getOperand(2);
+ if (isContractableFMUL(N002) &&
+ TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N00.getValueType())) {
+ return FoldFAddFPExtFMAFMul(N00.getOperand(0), N00.getOperand(1),
+ N002.getOperand(0), N002.getOperand(1),
+ N1, Flags);
}
}
+ }
- // fold (fadd x, (fma y, z, (fpext (fmul u, v)))
- // -> (fma y, z, (fma (fpext u), (fpext v), x))
- if (N1.getOpcode() == PreferredFusedOpcode) {
- SDValue N12 = N1.getOperand(2);
- if (N12.getOpcode() == ISD::FP_EXTEND) {
- SDValue N120 = N12.getOperand(0);
- if (N120.getOpcode() == ISD::FMUL)
- return FoldFAddFMAFPExtFMul(N1.getOperand(0), N1.getOperand(1),
- N120.getOperand(0), N120.getOperand(1),
- N0);
+ // fold (fadd x, (fma y, z, (fpext (fmul u, v)))
+ // -> (fma y, z, (fma (fpext u), (fpext v), x))
+ if (N1.getOpcode() == PreferredFusedOpcode) {
+ SDValue N12 = N1.getOperand(2);
+ if (N12.getOpcode() == ISD::FP_EXTEND) {
+ SDValue N120 = N12.getOperand(0);
+ if (isContractableFMUL(N120) &&
+ TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N120.getValueType())) {
+ return FoldFAddFMAFPExtFMul(N1.getOperand(0), N1.getOperand(1),
+ N120.getOperand(0), N120.getOperand(1),
+ N0, Flags);
}
}
+ }
- // fold (fadd x, (fpext (fma y, z, (fmul u, v)))
- // -> (fma (fpext y), (fpext z), (fma (fpext u), (fpext v), x))
- // FIXME: This turns two single-precision and one double-precision
- // operation into two double-precision operations, which might not be
- // interesting for all targets, especially GPUs.
- if (N1.getOpcode() == ISD::FP_EXTEND) {
- SDValue N10 = N1.getOperand(0);
- if (N10.getOpcode() == PreferredFusedOpcode) {
- SDValue N102 = N10.getOperand(2);
- if (N102.getOpcode() == ISD::FMUL)
- return FoldFAddFPExtFMAFMul(N10.getOperand(0), N10.getOperand(1),
- N102.getOperand(0), N102.getOperand(1),
- N0);
+ // fold (fadd x, (fpext (fma y, z, (fmul u, v))))
+ // -> (fma (fpext y), (fpext z), (fma (fpext u), (fpext v), x))
+ // FIXME: This turns two single-precision and one double-precision
+ // operation into two double-precision operations, which might not be
+ // interesting for all targets, especially GPUs.
+ if (N1.getOpcode() == ISD::FP_EXTEND) {
+ SDValue N10 = N1.getOperand(0);
+ if (N10.getOpcode() == PreferredFusedOpcode) {
+ SDValue N102 = N10.getOperand(2);
+ if (isContractableFMUL(N102) &&
+ TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N10.getValueType())) {
+ return FoldFAddFPExtFMAFMul(N10.getOperand(0), N10.getOperand(1),
+ N102.getOperand(0), N102.getOperand(1),
+ N0, Flags);
}
}
}
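// ---- Editorial sketch, not part of the patch ----
// The fold directly above widens the operands of an inner fmul through
// fp_extend and fuses everything into a double-precision FMA chain. A
// minimal scalar C++ analogue (hypothetical helper; float inputs stand
// in for SDValues):
#include <cmath>
static double fuseExtendedMulAdd(double x, float y, float z, float u, float v) {
  double inner = std::fma((double)u, (double)v, x); // (fma (fpext u), (fpext v), x)
  return std::fma((double)y, (double)z, inner);     // (fma (fpext y), (fpext z), inner)
}
// ---- End sketch ----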
@@ -7822,149 +10663,169 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
SDLoc SL(N);
const TargetOptions &Options = DAG.getTarget().Options;
- bool AllowFusion =
- (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath);
-
// Floating-point multiply-add with intermediate rounding.
bool HasFMAD = (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT));
// Floating-point multiply-add without intermediate rounding.
bool HasFMA =
- AllowFusion && TLI.isFMAFasterThanFMulAndFAdd(VT) &&
+ TLI.isFMAFasterThanFMulAndFAdd(VT) &&
(!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
// No valid opcode, do not combine.
if (!HasFMAD && !HasFMA)
return SDValue();
+ const SDNodeFlags Flags = N->getFlags();
+ bool CanFuse = Options.UnsafeFPMath || isContractable(N);
+ bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
+ CanFuse || HasFMAD);
+
+ // If the subtraction is not contractable, do not combine.
+ if (!AllowFusionGlobally && !isContractable(N))
+ return SDValue();
+
+ const SelectionDAGTargetInfo *STI = DAG.getSubtarget().getSelectionDAGInfo();
+ if (STI && STI->generateFMAsInMachineCombiner(OptLevel))
+ return SDValue();
+
// Always prefer FMAD to FMA for precision.
unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
- bool LookThroughFPExt = TLI.isFPExtFree(VT);
+
+ // Is the node an FMUL and contractable either due to global flags or
+ // SDNodeFlags.
+ auto isContractableFMUL = [AllowFusionGlobally](SDValue N) {
+ if (N.getOpcode() != ISD::FMUL)
+ return false;
+ return AllowFusionGlobally || isContractable(N.getNode());
+ };
// fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
- if (N0.getOpcode() == ISD::FMUL &&
- (Aggressive || N0->hasOneUse())) {
+ if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) {
return DAG.getNode(PreferredFusedOpcode, SL, VT,
N0.getOperand(0), N0.getOperand(1),
- DAG.getNode(ISD::FNEG, SL, VT, N1));
+ DAG.getNode(ISD::FNEG, SL, VT, N1), Flags);
}
// fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
// Note: Commutes FSUB operands.
- if (N1.getOpcode() == ISD::FMUL &&
- (Aggressive || N1->hasOneUse()))
+ if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) {
return DAG.getNode(PreferredFusedOpcode, SL, VT,
DAG.getNode(ISD::FNEG, SL, VT,
N1.getOperand(0)),
- N1.getOperand(1), N0);
+ N1.getOperand(1), N0, Flags);
+ }
// fold (fsub (fneg (fmul, x, y)), z) -> (fma (fneg x), y, (fneg z))
- if (N0.getOpcode() == ISD::FNEG &&
- N0.getOperand(0).getOpcode() == ISD::FMUL &&
+ if (N0.getOpcode() == ISD::FNEG && isContractableFMUL(N0.getOperand(0)) &&
(Aggressive || (N0->hasOneUse() && N0.getOperand(0).hasOneUse()))) {
SDValue N00 = N0.getOperand(0).getOperand(0);
SDValue N01 = N0.getOperand(0).getOperand(1);
return DAG.getNode(PreferredFusedOpcode, SL, VT,
DAG.getNode(ISD::FNEG, SL, VT, N00), N01,
- DAG.getNode(ISD::FNEG, SL, VT, N1));
+ DAG.getNode(ISD::FNEG, SL, VT, N1), Flags);
}
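// ---- Editorial sketch, not part of the patch ----
// The three FSUB folds above only move a negation around one fused
// multiply-add. Scalar analogues with hypothetical helper names:
#include <cmath>
static double subOfMul(double x, double y, double z) {
  return std::fma(x, y, -z);  // (fsub (fmul x, y), z)
}
static double subFromMul(double x, double y, double z) {
  return std::fma(-y, z, x);  // (fsub x, (fmul y, z)), operands commuted
}
static double subOfNegMul(double x, double y, double z) {
  return std::fma(-x, y, -z); // (fsub (fneg (fmul x, y)), z)
}
// ---- End sketch ----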
// Look through FP_EXTEND nodes to do more combining.
- if (AllowFusion && LookThroughFPExt) {
- // fold (fsub (fpext (fmul x, y)), z)
- // -> (fma (fpext x), (fpext y), (fneg z))
- if (N0.getOpcode() == ISD::FP_EXTEND) {
- SDValue N00 = N0.getOperand(0);
- if (N00.getOpcode() == ISD::FMUL)
- return DAG.getNode(PreferredFusedOpcode, SL, VT,
- DAG.getNode(ISD::FP_EXTEND, SL, VT,
- N00.getOperand(0)),
- DAG.getNode(ISD::FP_EXTEND, SL, VT,
- N00.getOperand(1)),
- DAG.getNode(ISD::FNEG, SL, VT, N1));
+
+ // fold (fsub (fpext (fmul x, y)), z)
+ // -> (fma (fpext x), (fpext y), (fneg z))
+ if (N0.getOpcode() == ISD::FP_EXTEND) {
+ SDValue N00 = N0.getOperand(0);
+ if (isContractableFMUL(N00) &&
+ TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N00.getValueType())) {
+ return DAG.getNode(PreferredFusedOpcode, SL, VT,
+ DAG.getNode(ISD::FP_EXTEND, SL, VT,
+ N00.getOperand(0)),
+ DAG.getNode(ISD::FP_EXTEND, SL, VT,
+ N00.getOperand(1)),
+ DAG.getNode(ISD::FNEG, SL, VT, N1), Flags);
}
+ }
- // fold (fsub x, (fpext (fmul y, z)))
- // -> (fma (fneg (fpext y)), (fpext z), x)
- // Note: Commutes FSUB operands.
- if (N1.getOpcode() == ISD::FP_EXTEND) {
- SDValue N10 = N1.getOperand(0);
- if (N10.getOpcode() == ISD::FMUL)
- return DAG.getNode(PreferredFusedOpcode, SL, VT,
- DAG.getNode(ISD::FNEG, SL, VT,
+ // fold (fsub x, (fpext (fmul y, z)))
+ // -> (fma (fneg (fpext y)), (fpext z), x)
+ // Note: Commutes FSUB operands.
+ if (N1.getOpcode() == ISD::FP_EXTEND) {
+ SDValue N10 = N1.getOperand(0);
+ if (isContractableFMUL(N10) &&
+ TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N10.getValueType())) {
+ return DAG.getNode(PreferredFusedOpcode, SL, VT,
+ DAG.getNode(ISD::FNEG, SL, VT,
+ DAG.getNode(ISD::FP_EXTEND, SL, VT,
+ N10.getOperand(0))),
+ DAG.getNode(ISD::FP_EXTEND, SL, VT,
+ N10.getOperand(1)),
+ N0, Flags);
+ }
+ }
+
+ // fold (fsub (fpext (fneg (fmul x, y))), z)
+ // -> (fneg (fma (fpext x), (fpext y), z))
+ // Note: This could be removed with appropriate canonicalization of the
+ // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However, the
+ // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent us
+ // from implementing the canonicalization in visitFSUB.
+ if (N0.getOpcode() == ISD::FP_EXTEND) {
+ SDValue N00 = N0.getOperand(0);
+ if (N00.getOpcode() == ISD::FNEG) {
+ SDValue N000 = N00.getOperand(0);
+ if (isContractableFMUL(N000) &&
+ TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N00.getValueType())) {
+ return DAG.getNode(ISD::FNEG, SL, VT,
+ DAG.getNode(PreferredFusedOpcode, SL, VT,
DAG.getNode(ISD::FP_EXTEND, SL, VT,
- N10.getOperand(0))),
- DAG.getNode(ISD::FP_EXTEND, SL, VT,
- N10.getOperand(1)),
- N0);
- }
-
- // fold (fsub (fpext (fneg (fmul, x, y))), z)
- // -> (fneg (fma (fpext x), (fpext y), z))
- // Note: This could be removed with appropriate canonicalization of the
- // input expression into (fneg (fadd (fpext (fmul, x, y)), z). However, the
- // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent
- // from implementing the canonicalization in visitFSUB.
- if (N0.getOpcode() == ISD::FP_EXTEND) {
- SDValue N00 = N0.getOperand(0);
- if (N00.getOpcode() == ISD::FNEG) {
- SDValue N000 = N00.getOperand(0);
- if (N000.getOpcode() == ISD::FMUL) {
- return DAG.getNode(ISD::FNEG, SL, VT,
- DAG.getNode(PreferredFusedOpcode, SL, VT,
- DAG.getNode(ISD::FP_EXTEND, SL, VT,
- N000.getOperand(0)),
- DAG.getNode(ISD::FP_EXTEND, SL, VT,
- N000.getOperand(1)),
- N1));
- }
+ N000.getOperand(0)),
+ DAG.getNode(ISD::FP_EXTEND, SL, VT,
+ N000.getOperand(1)),
+ N1, Flags));
}
}
+ }
- // fold (fsub (fneg (fpext (fmul, x, y))), z)
- // -> (fneg (fma (fpext x)), (fpext y), z)
- // Note: This could be removed with appropriate canonicalization of the
- // input expression into (fneg (fadd (fpext (fmul, x, y)), z). However, the
- // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent
- // from implementing the canonicalization in visitFSUB.
- if (N0.getOpcode() == ISD::FNEG) {
- SDValue N00 = N0.getOperand(0);
- if (N00.getOpcode() == ISD::FP_EXTEND) {
- SDValue N000 = N00.getOperand(0);
- if (N000.getOpcode() == ISD::FMUL) {
- return DAG.getNode(ISD::FNEG, SL, VT,
- DAG.getNode(PreferredFusedOpcode, SL, VT,
- DAG.getNode(ISD::FP_EXTEND, SL, VT,
- N000.getOperand(0)),
- DAG.getNode(ISD::FP_EXTEND, SL, VT,
- N000.getOperand(1)),
- N1));
- }
+ // fold (fsub (fneg (fpext (fmul x, y))), z)
+ // -> (fneg (fma (fpext x), (fpext y), z))
+ // Note: This could be removed with appropriate canonicalization of the
+ // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However, the
+ // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent us
+ // from implementing the canonicalization in visitFSUB.
+ if (N0.getOpcode() == ISD::FNEG) {
+ SDValue N00 = N0.getOperand(0);
+ if (N00.getOpcode() == ISD::FP_EXTEND) {
+ SDValue N000 = N00.getOperand(0);
+ if (isContractableFMUL(N000) &&
+ TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N000.getValueType())) {
+ return DAG.getNode(ISD::FNEG, SL, VT,
+ DAG.getNode(PreferredFusedOpcode, SL, VT,
+ DAG.getNode(ISD::FP_EXTEND, SL, VT,
+ N000.getOperand(0)),
+ DAG.getNode(ISD::FP_EXTEND, SL, VT,
+ N000.getOperand(1)),
+ N1, Flags));
}
}
-
}
// More folding opportunities when target permits.
- if ((AllowFusion || HasFMAD) && Aggressive) {
+ if (Aggressive) {
// fold (fsub (fma x, y, (fmul u, v)), z)
// -> (fma x, y, (fma u, v, (fneg z)))
- if (N0.getOpcode() == PreferredFusedOpcode &&
- N0.getOperand(2).getOpcode() == ISD::FMUL) {
+ if (CanFuse && N0.getOpcode() == PreferredFusedOpcode &&
+ isContractableFMUL(N0.getOperand(2)) && N0->hasOneUse() &&
+ N0.getOperand(2)->hasOneUse()) {
return DAG.getNode(PreferredFusedOpcode, SL, VT,
N0.getOperand(0), N0.getOperand(1),
DAG.getNode(PreferredFusedOpcode, SL, VT,
N0.getOperand(2).getOperand(0),
N0.getOperand(2).getOperand(1),
DAG.getNode(ISD::FNEG, SL, VT,
- N1)));
+ N1), Flags), Flags);
}
// fold (fsub x, (fma y, z, (fmul u, v)))
// -> (fma (fneg y), z, (fma (fneg u), v, x))
- if (N1.getOpcode() == PreferredFusedOpcode &&
- N1.getOperand(2).getOpcode() == ISD::FMUL) {
+ if (CanFuse && N1.getOpcode() == PreferredFusedOpcode &&
+ isContractableFMUL(N1.getOperand(2))) {
SDValue N20 = N1.getOperand(2).getOperand(0);
SDValue N21 = N1.getOperand(2).getOperand(1);
return DAG.getNode(PreferredFusedOpcode, SL, VT,
@@ -7973,132 +10834,146 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
N1.getOperand(1),
DAG.getNode(PreferredFusedOpcode, SL, VT,
DAG.getNode(ISD::FNEG, SL, VT, N20),
+ N21, N0, Flags), Flags);
+ }
- N21, N0));
- }
-
- if (AllowFusion && LookThroughFPExt) {
- // fold (fsub (fma x, y, (fpext (fmul u, v))), z)
- // -> (fma x, y (fma (fpext u), (fpext v), (fneg z)))
- if (N0.getOpcode() == PreferredFusedOpcode) {
- SDValue N02 = N0.getOperand(2);
- if (N02.getOpcode() == ISD::FP_EXTEND) {
- SDValue N020 = N02.getOperand(0);
- if (N020.getOpcode() == ISD::FMUL)
- return DAG.getNode(PreferredFusedOpcode, SL, VT,
- N0.getOperand(0), N0.getOperand(1),
- DAG.getNode(PreferredFusedOpcode, SL, VT,
- DAG.getNode(ISD::FP_EXTEND, SL, VT,
- N020.getOperand(0)),
- DAG.getNode(ISD::FP_EXTEND, SL, VT,
- N020.getOperand(1)),
- DAG.getNode(ISD::FNEG, SL, VT,
- N1)));
- }
- }
-
- // fold (fsub (fpext (fma x, y, (fmul u, v))), z)
- // -> (fma (fpext x), (fpext y),
- // (fma (fpext u), (fpext v), (fneg z)))
- // FIXME: This turns two single-precision and one double-precision
- // operation into two double-precision operations, which might not be
- // interesting for all targets, especially GPUs.
- if (N0.getOpcode() == ISD::FP_EXTEND) {
- SDValue N00 = N0.getOperand(0);
- if (N00.getOpcode() == PreferredFusedOpcode) {
- SDValue N002 = N00.getOperand(2);
- if (N002.getOpcode() == ISD::FMUL)
- return DAG.getNode(PreferredFusedOpcode, SL, VT,
- DAG.getNode(ISD::FP_EXTEND, SL, VT,
- N00.getOperand(0)),
- DAG.getNode(ISD::FP_EXTEND, SL, VT,
- N00.getOperand(1)),
- DAG.getNode(PreferredFusedOpcode, SL, VT,
- DAG.getNode(ISD::FP_EXTEND, SL, VT,
- N002.getOperand(0)),
- DAG.getNode(ISD::FP_EXTEND, SL, VT,
- N002.getOperand(1)),
- DAG.getNode(ISD::FNEG, SL, VT,
- N1)));
- }
- }
- // fold (fsub x, (fma y, z, (fpext (fmul u, v))))
- // -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x))
- if (N1.getOpcode() == PreferredFusedOpcode &&
- N1.getOperand(2).getOpcode() == ISD::FP_EXTEND) {
- SDValue N120 = N1.getOperand(2).getOperand(0);
- if (N120.getOpcode() == ISD::FMUL) {
- SDValue N1200 = N120.getOperand(0);
- SDValue N1201 = N120.getOperand(1);
+ // fold (fsub (fma x, y, (fpext (fmul u, v))), z)
+ // -> (fma x, y, (fma (fpext u), (fpext v), (fneg z)))
+ if (N0.getOpcode() == PreferredFusedOpcode) {
+ SDValue N02 = N0.getOperand(2);
+ if (N02.getOpcode() == ISD::FP_EXTEND) {
+ SDValue N020 = N02.getOperand(0);
+ if (isContractableFMUL(N020) &&
+ TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N020.getValueType())) {
return DAG.getNode(PreferredFusedOpcode, SL, VT,
- DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)),
- N1.getOperand(1),
+ N0.getOperand(0), N0.getOperand(1),
DAG.getNode(PreferredFusedOpcode, SL, VT,
- DAG.getNode(ISD::FNEG, SL, VT,
- DAG.getNode(ISD::FP_EXTEND, SL,
- VT, N1200)),
DAG.getNode(ISD::FP_EXTEND, SL, VT,
- N1201),
- N0));
+ N020.getOperand(0)),
+ DAG.getNode(ISD::FP_EXTEND, SL, VT,
+ N020.getOperand(1)),
+ DAG.getNode(ISD::FNEG, SL, VT,
+ N1), Flags), Flags);
}
}
+ }
- // fold (fsub x, (fpext (fma y, z, (fmul u, v))))
- // -> (fma (fneg (fpext y)), (fpext z),
- // (fma (fneg (fpext u)), (fpext v), x))
- // FIXME: This turns two single-precision and one double-precision
- // operation into two double-precision operations, which might not be
- // interesting for all targets, especially GPUs.
- if (N1.getOpcode() == ISD::FP_EXTEND &&
- N1.getOperand(0).getOpcode() == PreferredFusedOpcode) {
- SDValue N100 = N1.getOperand(0).getOperand(0);
- SDValue N101 = N1.getOperand(0).getOperand(1);
- SDValue N102 = N1.getOperand(0).getOperand(2);
- if (N102.getOpcode() == ISD::FMUL) {
- SDValue N1020 = N102.getOperand(0);
- SDValue N1021 = N102.getOperand(1);
+ // fold (fsub (fpext (fma x, y, (fmul u, v))), z)
+ // -> (fma (fpext x), (fpext y),
+ // (fma (fpext u), (fpext v), (fneg z)))
+ // FIXME: This turns two single-precision and one double-precision
+ // operation into two double-precision operations, which might not be
+ // interesting for all targets, especially GPUs.
+ if (N0.getOpcode() == ISD::FP_EXTEND) {
+ SDValue N00 = N0.getOperand(0);
+ if (N00.getOpcode() == PreferredFusedOpcode) {
+ SDValue N002 = N00.getOperand(2);
+ if (isContractableFMUL(N002) &&
+ TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N00.getValueType())) {
return DAG.getNode(PreferredFusedOpcode, SL, VT,
- DAG.getNode(ISD::FNEG, SL, VT,
- DAG.getNode(ISD::FP_EXTEND, SL, VT,
- N100)),
- DAG.getNode(ISD::FP_EXTEND, SL, VT, N101),
+ DAG.getNode(ISD::FP_EXTEND, SL, VT,
+ N00.getOperand(0)),
+ DAG.getNode(ISD::FP_EXTEND, SL, VT,
+ N00.getOperand(1)),
DAG.getNode(PreferredFusedOpcode, SL, VT,
- DAG.getNode(ISD::FNEG, SL, VT,
- DAG.getNode(ISD::FP_EXTEND, SL,
- VT, N1020)),
DAG.getNode(ISD::FP_EXTEND, SL, VT,
- N1021),
- N0));
+ N002.getOperand(0)),
+ DAG.getNode(ISD::FP_EXTEND, SL, VT,
+ N002.getOperand(1)),
+ DAG.getNode(ISD::FNEG, SL, VT,
+ N1), Flags), Flags);
}
}
}
+
+ // fold (fsub x, (fma y, z, (fpext (fmul u, v))))
+ // -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x))
+ if (N1.getOpcode() == PreferredFusedOpcode &&
+ N1.getOperand(2).getOpcode() == ISD::FP_EXTEND) {
+ SDValue N120 = N1.getOperand(2).getOperand(0);
+ if (isContractableFMUL(N120) &&
+ TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N120.getValueType())) {
+ SDValue N1200 = N120.getOperand(0);
+ SDValue N1201 = N120.getOperand(1);
+ return DAG.getNode(PreferredFusedOpcode, SL, VT,
+ DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)),
+ N1.getOperand(1),
+ DAG.getNode(PreferredFusedOpcode, SL, VT,
+ DAG.getNode(ISD::FNEG, SL, VT,
+ DAG.getNode(ISD::FP_EXTEND, SL,
+ VT, N1200)),
+ DAG.getNode(ISD::FP_EXTEND, SL, VT,
+ N1201),
+ N0, Flags), Flags);
+ }
+ }
+
+ // fold (fsub x, (fpext (fma y, z, (fmul u, v))))
+ // -> (fma (fneg (fpext y)), (fpext z),
+ // (fma (fneg (fpext u)), (fpext v), x))
+ // FIXME: This turns two single-precision and one double-precision
+ // operation into two double-precision operations, which might not be
+ // interesting for all targets, especially GPUs.
+ if (N1.getOpcode() == ISD::FP_EXTEND &&
+ N1.getOperand(0).getOpcode() == PreferredFusedOpcode) {
+ SDValue CvtSrc = N1.getOperand(0);
+ SDValue N100 = CvtSrc.getOperand(0);
+ SDValue N101 = CvtSrc.getOperand(1);
+ SDValue N102 = CvtSrc.getOperand(2);
+ if (isContractableFMUL(N102) &&
+ TLI.isFPExtFoldable(PreferredFusedOpcode, VT, CvtSrc.getValueType())) {
+ SDValue N1020 = N102.getOperand(0);
+ SDValue N1021 = N102.getOperand(1);
+ return DAG.getNode(PreferredFusedOpcode, SL, VT,
+ DAG.getNode(ISD::FNEG, SL, VT,
+ DAG.getNode(ISD::FP_EXTEND, SL, VT,
+ N100)),
+ DAG.getNode(ISD::FP_EXTEND, SL, VT, N101),
+ DAG.getNode(PreferredFusedOpcode, SL, VT,
+ DAG.getNode(ISD::FNEG, SL, VT,
+ DAG.getNode(ISD::FP_EXTEND, SL,
+ VT, N1020)),
+ DAG.getNode(ISD::FP_EXTEND, SL, VT,
+ N1021),
+ N0, Flags), Flags);
+ }
+ }
}
return SDValue();
}
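// ---- Editorial sketch, not part of the patch ----
// Under aggressive fusion, visitFSUBForFMACombine also nests FMAs, e.g.
// (fsub (fma x, y, (fmul u, v)), z) -> (fma x, y, (fma u, v, (fneg z))).
// A scalar analogue (hypothetical helper):
#include <cmath>
static double nestedFma(double x, double y, double u, double v, double z) {
  return std::fma(x, y, std::fma(u, v, -z));
}
// ---- End sketch ----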
-/// Try to perform FMA combining on a given FMUL node.
-SDValue DAGCombiner::visitFMULForFMACombine(SDNode *N) {
+/// Try to perform FMA combining on a given FMUL node based on the distributive
+/// law x * (y + 1) = x * y + x and variants thereof (commuted versions,
+/// subtraction instead of addition).
+SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
EVT VT = N->getValueType(0);
SDLoc SL(N);
+ const SDNodeFlags Flags = N->getFlags();
assert(N->getOpcode() == ISD::FMUL && "Expected FMUL Operation");
const TargetOptions &Options = DAG.getTarget().Options;
- bool AllowFusion =
- (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath);
- // Floating-point multiply-add with intermediate rounding.
- bool HasFMAD = (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT));
+ // The transforms below are incorrect when x == 0 and y == inf, because the
+ // intermediate multiplication produces a nan.
+ if (!Options.NoInfsFPMath)
+ return SDValue();
// Floating-point multiply-add without intermediate rounding.
bool HasFMA =
- AllowFusion && TLI.isFMAFasterThanFMulAndFAdd(VT) &&
+ (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath) &&
+ TLI.isFMAFasterThanFMulAndFAdd(VT) &&
(!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
+ // Floating-point multiply-add with intermediate rounding. This can result
+ // in a less precise result due to the changed rounding order.
+ bool HasFMAD = Options.UnsafeFPMath &&
+ (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT));
+
// No valid opcode, do not combine.
if (!HasFMAD && !HasFMA)
return SDValue();
@@ -8107,54 +10982,58 @@ SDValue DAGCombiner::visitFMULForFMACombine(SDNode *N) {
unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
- // fold (fmul (fadd x, +1.0), y) -> (fma x, y, y)
- // fold (fmul (fadd x, -1.0), y) -> (fma x, y, (fneg y))
- auto FuseFADD = [&](SDValue X, SDValue Y) {
+ // fold (fmul (fadd x0, +1.0), y) -> (fma x0, y, y)
+ // fold (fmul (fadd x0, -1.0), y) -> (fma x0, y, (fneg y))
+ auto FuseFADD = [&](SDValue X, SDValue Y, const SDNodeFlags Flags) {
if (X.getOpcode() == ISD::FADD && (Aggressive || X->hasOneUse())) {
- auto XC1 = isConstOrConstSplatFP(X.getOperand(1));
- if (XC1 && XC1->isExactlyValue(+1.0))
- return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, Y);
- if (XC1 && XC1->isExactlyValue(-1.0))
- return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
- DAG.getNode(ISD::FNEG, SL, VT, Y));
+ if (auto *C = isConstOrConstSplatFP(X.getOperand(1), true)) {
+ if (C->isExactlyValue(+1.0))
+ return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
+ Y, Flags);
+ if (C->isExactlyValue(-1.0))
+ return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
+ DAG.getNode(ISD::FNEG, SL, VT, Y), Flags);
+ }
}
return SDValue();
};
- if (SDValue FMA = FuseFADD(N0, N1))
+ if (SDValue FMA = FuseFADD(N0, N1, Flags))
return FMA;
- if (SDValue FMA = FuseFADD(N1, N0))
+ if (SDValue FMA = FuseFADD(N1, N0, Flags))
return FMA;
- // fold (fmul (fsub +1.0, x), y) -> (fma (fneg x), y, y)
- // fold (fmul (fsub -1.0, x), y) -> (fma (fneg x), y, (fneg y))
- // fold (fmul (fsub x, +1.0), y) -> (fma x, y, (fneg y))
- // fold (fmul (fsub x, -1.0), y) -> (fma x, y, y)
- auto FuseFSUB = [&](SDValue X, SDValue Y) {
+ // fold (fmul (fsub +1.0, x1), y) -> (fma (fneg x1), y, y)
+ // fold (fmul (fsub -1.0, x1), y) -> (fma (fneg x1), y, (fneg y))
+ // fold (fmul (fsub x0, +1.0), y) -> (fma x0, y, (fneg y))
+ // fold (fmul (fsub x0, -1.0), y) -> (fma x0, y, y)
+ auto FuseFSUB = [&](SDValue X, SDValue Y, const SDNodeFlags Flags) {
if (X.getOpcode() == ISD::FSUB && (Aggressive || X->hasOneUse())) {
- auto XC0 = isConstOrConstSplatFP(X.getOperand(0));
- if (XC0 && XC0->isExactlyValue(+1.0))
- return DAG.getNode(PreferredFusedOpcode, SL, VT,
- DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
- Y);
- if (XC0 && XC0->isExactlyValue(-1.0))
- return DAG.getNode(PreferredFusedOpcode, SL, VT,
- DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
- DAG.getNode(ISD::FNEG, SL, VT, Y));
-
- auto XC1 = isConstOrConstSplatFP(X.getOperand(1));
- if (XC1 && XC1->isExactlyValue(+1.0))
- return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
- DAG.getNode(ISD::FNEG, SL, VT, Y));
- if (XC1 && XC1->isExactlyValue(-1.0))
- return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, Y);
+ if (auto *C0 = isConstOrConstSplatFP(X.getOperand(0), true)) {
+ if (C0->isExactlyValue(+1.0))
+ return DAG.getNode(PreferredFusedOpcode, SL, VT,
+ DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
+ Y, Flags);
+ if (C0->isExactlyValue(-1.0))
+ return DAG.getNode(PreferredFusedOpcode, SL, VT,
+ DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
+ DAG.getNode(ISD::FNEG, SL, VT, Y), Flags);
+ }
+ if (auto *C1 = isConstOrConstSplatFP(X.getOperand(1), true)) {
+ if (C1->isExactlyValue(+1.0))
+ return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
+ DAG.getNode(ISD::FNEG, SL, VT, Y), Flags);
+ if (C1->isExactlyValue(-1.0))
+ return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
+ Y, Flags);
+ }
}
return SDValue();
};
- if (SDValue FMA = FuseFSUB(N0, N1))
+ if (SDValue FMA = FuseFSUB(N0, N1, Flags))
return FMA;
- if (SDValue FMA = FuseFSUB(N1, N0))
+ if (SDValue FMA = FuseFSUB(N1, N0, Flags))
return FMA;
return SDValue();
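// ---- Editorial sketch, not part of the patch ----
// The distributive combine above rests on x * (y ± 1) = x*y ± x. Scalar
// analogues of the FuseFADD/FuseFSUB patterns (hypothetical helpers):
#include <cmath>
static double mulAddOne(double x, double y) { return std::fma(x, y, x); }  // x * (y + 1.0)
static double mulSubOne(double x, double y) { return std::fma(x, y, -x); } // x * (y - 1.0)
static double oneSubMul(double x, double y) { return std::fma(-x, y, y); } // (1.0 - x) * y
// ---- End sketch ----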
@@ -8168,7 +11047,7 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
EVT VT = N->getValueType(0);
SDLoc DL(N);
const TargetOptions &Options = DAG.getTarget().Options;
- const SDNodeFlags *Flags = &cast<BinaryWithFlagsSDNode>(N)->Flags;
+ const SDNodeFlags Flags = N->getFlags();
// fold vector ops
if (VT.isVector())
@@ -8183,6 +11062,15 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
if (N0CFP && !N1CFP)
return DAG.getNode(ISD::FADD, DL, VT, N1, N0, Flags);
+ // N0 + -0.0 --> N0 (also allowed with +0.0 and fast-math)
+ ConstantFPSDNode *N1C = isConstOrConstSplatFP(N1, true);
+ if (N1C && N1C->isZero())
+ if (N1C->isNegative() || Options.UnsafeFPMath || Flags.hasNoSignedZeros())
+ return N0;
+
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
// fold (fadd A, (fneg B)) -> (fsub A, B)
if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT)) &&
isNegatibleForFree(N1, LegalOperations, TLI, &Options) == 2)
@@ -8195,32 +11083,53 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
return DAG.getNode(ISD::FSUB, DL, VT, N1,
GetNegatedExpression(N0, DAG, LegalOperations), Flags);
- // If 'unsafe math' is enabled, fold lots of things.
- if (Options.UnsafeFPMath) {
- // No FP constant should be created after legalization as Instruction
- // Selection pass has a hard time dealing with FP constants.
- bool AllowNewConst = (Level < AfterLegalizeDAG);
+ auto isFMulNegTwo = [](SDValue FMul) {
+ if (!FMul.hasOneUse() || FMul.getOpcode() != ISD::FMUL)
+ return false;
+ auto *C = isConstOrConstSplatFP(FMul.getOperand(1), true);
+ return C && C->isExactlyValue(-2.0);
+ };
- // fold (fadd A, 0) -> A
- if (ConstantFPSDNode *N1C = isConstOrConstSplatFP(N1))
- if (N1C->isZero())
- return N0;
+ // fadd (fmul B, -2.0), A --> fsub A, (fadd B, B)
+ if (isFMulNegTwo(N0)) {
+ SDValue B = N0.getOperand(0);
+ SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B, Flags);
+ return DAG.getNode(ISD::FSUB, DL, VT, N1, Add, Flags);
+ }
+ // fadd A, (fmul B, -2.0) --> fsub A, (fadd B, B)
+ if (isFMulNegTwo(N1)) {
+ SDValue B = N1.getOperand(0);
+ SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B, Flags);
+ return DAG.getNode(ISD::FSUB, DL, VT, N0, Add, Flags);
+ }
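// ---- Editorial sketch, not part of the patch ----
// isFMulNegTwo above trades a multiply by -2.0 for an add and a subtract:
// a + b * -2.0 == a - (b + b). Scalar analogue (hypothetical helper):
static double addMulNegTwo(double a, double b) { return a - (b + b); }
// ---- End sketch ----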
- // fold (fadd (fadd x, c1), c2) -> (fadd x, (fadd c1, c2))
- if (N1CFP && N0.getOpcode() == ISD::FADD && N0.getNode()->hasOneUse() &&
- isConstantFPBuildVectorOrConstantFP(N0.getOperand(1)))
- return DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(0),
- DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1), N1,
- Flags),
- Flags);
+ // No FP constant should be created after legalization as Instruction
+ // Selection pass has a hard time dealing with FP constants.
+ bool AllowNewConst = (Level < AfterLegalizeDAG);
+ // If 'unsafe math' or nnan is enabled, fold lots of things.
+ if ((Options.UnsafeFPMath || Flags.hasNoNaNs()) && AllowNewConst) {
// If allowed, fold (fadd (fneg x), x) -> 0.0
- if (AllowNewConst && N0.getOpcode() == ISD::FNEG && N0.getOperand(0) == N1)
+ if (N0.getOpcode() == ISD::FNEG && N0.getOperand(0) == N1)
return DAG.getConstantFP(0.0, DL, VT);
// If allowed, fold (fadd x, (fneg x)) -> 0.0
- if (AllowNewConst && N1.getOpcode() == ISD::FNEG && N1.getOperand(0) == N0)
+ if (N1.getOpcode() == ISD::FNEG && N1.getOperand(0) == N0)
return DAG.getConstantFP(0.0, DL, VT);
+ }
+
+ // If 'unsafe math' or reassoc and nsz, fold lots of things.
+ // TODO: break out portions of the transformations below for which Unsafe is
+ // considered and which do not require both nsz and reassoc
+ if ((Options.UnsafeFPMath ||
+ (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
+ AllowNewConst) {
+ // fadd (fadd x, c1), c2 -> fadd x, c1 + c2
+ if (N1CFP && N0.getOpcode() == ISD::FADD &&
+ isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
+ SDValue NewC = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1), N1, Flags);
+ return DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(0), NewC, Flags);
+ }
// We can fold chains of FADD's of the same value into multiplications.
// This transform is not safe in general because we are reducing the number
@@ -8268,7 +11177,7 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
}
}
- if (N0.getOpcode() == ISD::FADD && AllowNewConst) {
+ if (N0.getOpcode() == ISD::FADD) {
bool CFP00 = isConstantFPBuildVectorOrConstantFP(N0.getOperand(0));
// (fadd (fadd x, x), x) -> (fmul x, 3.0)
if (!CFP00 && N0.getOperand(0) == N0.getOperand(1) &&
@@ -8278,7 +11187,7 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
}
}
- if (N1.getOpcode() == ISD::FADD && AllowNewConst) {
+ if (N1.getOpcode() == ISD::FADD) {
bool CFP10 = isConstantFPBuildVectorOrConstantFP(N1.getOperand(0));
// (fadd x, (fadd x, x)) -> (fmul x, 3.0)
if (!CFP10 && N1.getOperand(0) == N1.getOperand(1) &&
@@ -8289,8 +11198,7 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
}
// (fadd (fadd x, x), (fadd x, x)) -> (fmul x, 4.0)
- if (AllowNewConst &&
- N0.getOpcode() == ISD::FADD && N1.getOpcode() == ISD::FADD &&
+ if (N0.getOpcode() == ISD::FADD && N1.getOpcode() == ISD::FADD &&
N0.getOperand(0) == N0.getOperand(1) &&
N1.getOperand(0) == N1.getOperand(1) &&
N0.getOperand(0) == N1.getOperand(0)) {
@@ -8305,19 +11213,18 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
AddToWorklist(Fused.getNode());
return Fused;
}
-
return SDValue();
}
SDValue DAGCombiner::visitFSUB(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
- ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0);
- ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1);
+ ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, true);
+ ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
EVT VT = N->getValueType(0);
- SDLoc dl(N);
+ SDLoc DL(N);
const TargetOptions &Options = DAG.getTarget().Options;
- const SDNodeFlags *Flags = &cast<BinaryWithFlagsSDNode>(N)->Flags;
+ const SDNodeFlags Flags = N->getFlags();
// fold vector ops
if (VT.isVector())
@@ -8326,45 +11233,52 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) {
// fold (fsub c1, c2) -> c1-c2
if (N0CFP && N1CFP)
- return DAG.getNode(ISD::FSUB, dl, VT, N0, N1, Flags);
+ return DAG.getNode(ISD::FSUB, DL, VT, N0, N1, Flags);
- // fold (fsub A, (fneg B)) -> (fadd A, B)
- if (isNegatibleForFree(N1, LegalOperations, TLI, &Options))
- return DAG.getNode(ISD::FADD, dl, VT, N0,
- GetNegatedExpression(N1, DAG, LegalOperations), Flags);
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
- // If 'unsafe math' is enabled, fold lots of things.
- if (Options.UnsafeFPMath) {
- // (fsub A, 0) -> A
- if (N1CFP && N1CFP->isZero())
+ // (fsub A, 0) -> A
+ if (N1CFP && N1CFP->isZero()) {
+ if (!N1CFP->isNegative() || Options.UnsafeFPMath ||
+ Flags.hasNoSignedZeros()) {
return N0;
+ }
+ }
+
+ if (N0 == N1) {
+ // (fsub x, x) -> 0.0
+ if (Options.UnsafeFPMath || Flags.hasNoNaNs())
+ return DAG.getConstantFP(0.0f, DL, VT);
+ }
- // (fsub 0, B) -> -B
- if (N0CFP && N0CFP->isZero()) {
+ // (fsub -0.0, N1) -> -N1
+ if (N0CFP && N0CFP->isZero()) {
+ if (N0CFP->isNegative() ||
+ (Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())) {
if (isNegatibleForFree(N1, LegalOperations, TLI, &Options))
return GetNegatedExpression(N1, DAG, LegalOperations);
if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
- return DAG.getNode(ISD::FNEG, dl, VT, N1);
+ return DAG.getNode(ISD::FNEG, DL, VT, N1, Flags);
}
+ }
- // (fsub x, x) -> 0.0
- if (N0 == N1)
- return DAG.getConstantFP(0.0f, dl, VT);
-
- // (fsub x, (fadd x, y)) -> (fneg y)
- // (fsub x, (fadd y, x)) -> (fneg y)
- if (N1.getOpcode() == ISD::FADD) {
- SDValue N10 = N1->getOperand(0);
- SDValue N11 = N1->getOperand(1);
-
- if (N10 == N0 && isNegatibleForFree(N11, LegalOperations, TLI, &Options))
- return GetNegatedExpression(N11, DAG, LegalOperations);
-
- if (N11 == N0 && isNegatibleForFree(N10, LegalOperations, TLI, &Options))
- return GetNegatedExpression(N10, DAG, LegalOperations);
- }
+ if ((Options.UnsafeFPMath ||
+ (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros()))
+ && N1.getOpcode() == ISD::FADD) {
+ // X - (X + Y) -> -Y
+ if (N0 == N1->getOperand(0))
+ return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(1), Flags);
+ // X - (Y + X) -> -Y
+ if (N0 == N1->getOperand(1))
+ return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(0), Flags);
}
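// ---- Editorial sketch, not part of the patch ----
// With reassociation and no signed zeros, x - (x + y) collapses to -y.
// The guards matter: for x == y == 0.0 the original yields +0.0 while -y
// is -0.0, and for x == inf it yields NaN. Scalar analogue:
static double subOfSum(double x, double y) {
  (void)x;   // cancels under reassoc + nsz
  return -y; // x - (x + y)
}
// ---- End sketch ----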
+ // fold (fsub A, (fneg B)) -> (fadd A, B)
+ if (isNegatibleForFree(N1, LegalOperations, TLI, &Options))
+ return DAG.getNode(ISD::FADD, DL, VT, N0,
+ GetNegatedExpression(N1, DAG, LegalOperations), Flags);
+
// FSUB -> FMA combines:
if (SDValue Fused = visitFSUBForFMACombine(N)) {
AddToWorklist(Fused.getNode());
@@ -8377,12 +11291,12 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) {
SDValue DAGCombiner::visitFMUL(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
- ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0);
- ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1);
+ ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, true);
+ ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
EVT VT = N->getValueType(0);
SDLoc DL(N);
const TargetOptions &Options = DAG.getTarget().Options;
- const SDNodeFlags *Flags = &cast<BinaryWithFlagsSDNode>(N)->Flags;
+ const SDNodeFlags Flags = N->getFlags();
// fold vector ops
if (VT.isVector()) {
@@ -8404,42 +11318,35 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) {
if (N1CFP && N1CFP->isExactlyValue(1.0))
return N0;
- if (Options.UnsafeFPMath) {
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
+ if (Options.UnsafeFPMath ||
+ (Flags.hasNoNaNs() && Flags.hasNoSignedZeros())) {
// fold (fmul A, 0) -> 0
if (N1CFP && N1CFP->isZero())
return N1;
+ }
- // fold (fmul (fmul x, c1), c2) -> (fmul x, (fmul c1, c2))
- if (N0.getOpcode() == ISD::FMUL) {
- // Fold scalars or any vector constants (not just splats).
- // This fold is done in general by InstCombine, but extra fmul insts
- // may have been generated during lowering.
+ if (Options.UnsafeFPMath || Flags.hasAllowReassociation()) {
+ // fmul (fmul X, C1), C2 -> fmul X, C1 * C2
+ if (isConstantFPBuildVectorOrConstantFP(N1) &&
+ N0.getOpcode() == ISD::FMUL) {
SDValue N00 = N0.getOperand(0);
SDValue N01 = N0.getOperand(1);
- auto *BV1 = dyn_cast<BuildVectorSDNode>(N1);
- auto *BV00 = dyn_cast<BuildVectorSDNode>(N00);
- auto *BV01 = dyn_cast<BuildVectorSDNode>(N01);
-
- // Check 1: Make sure that the first operand of the inner multiply is NOT
- // a constant. Otherwise, we may induce infinite looping.
- if (!(isConstOrConstSplatFP(N00) || (BV00 && BV00->isConstant()))) {
- // Check 2: Make sure that the second operand of the inner multiply and
- // the second operand of the outer multiply are constants.
- if ((N1CFP && isConstOrConstSplatFP(N01)) ||
- (BV1 && BV01 && BV1->isConstant() && BV01->isConstant())) {
- SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, N01, N1, Flags);
- return DAG.getNode(ISD::FMUL, DL, VT, N00, MulConsts, Flags);
- }
+ // Avoid an infinite loop by making sure that N00 is not a constant
+ // (the inner multiply has not been constant folded yet).
+ if (isConstantFPBuildVectorOrConstantFP(N01) &&
+ !isConstantFPBuildVectorOrConstantFP(N00)) {
+ SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, N01, N1, Flags);
+ return DAG.getNode(ISD::FMUL, DL, VT, N00, MulConsts, Flags);
}
}
- // fold (fmul (fadd x, x), c) -> (fmul x, (fmul 2.0, c))
- // Undo the fmul 2.0, x -> fadd x, x transformation, since if it occurs
- // during an early run of DAGCombiner can prevent folding with fmuls
- // inserted during lowering.
- if (N0.getOpcode() == ISD::FADD &&
- (N0.getOperand(0) == N0.getOperand(1)) &&
- N0.hasOneUse()) {
+ // Match a special-case: we convert X * 2.0 into fadd.
+ // fmul (fadd X, X), C -> fmul X, 2.0 * C
+ if (N0.getOpcode() == ISD::FADD && N0.hasOneUse() &&
+ N0.getOperand(0) == N0.getOperand(1)) {
const SDValue Two = DAG.getConstantFP(2.0, DL, VT);
SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, Two, N1, Flags);
return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), MulConsts, Flags);
@@ -8468,8 +11375,54 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) {
}
}
+ // fold (fmul X, (select (fcmp X > 0.0), -1.0, 1.0)) -> (fneg (fabs X))
+ // fold (fmul X, (select (fcmp X > 0.0), 1.0, -1.0)) -> (fabs X)
+ if (Flags.hasNoNaNs() && Flags.hasNoSignedZeros() &&
+ (N0.getOpcode() == ISD::SELECT || N1.getOpcode() == ISD::SELECT) &&
+ TLI.isOperationLegal(ISD::FABS, VT)) {
+ SDValue Select = N0, X = N1;
+ if (Select.getOpcode() != ISD::SELECT)
+ std::swap(Select, X);
+
+ SDValue Cond = Select.getOperand(0);
+ auto TrueOpnd = dyn_cast<ConstantFPSDNode>(Select.getOperand(1));
+ auto FalseOpnd = dyn_cast<ConstantFPSDNode>(Select.getOperand(2));
+
+ if (TrueOpnd && FalseOpnd &&
+ Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == X &&
+ isa<ConstantFPSDNode>(Cond.getOperand(1)) &&
+ cast<ConstantFPSDNode>(Cond.getOperand(1))->isExactlyValue(0.0)) {
+ ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
+ switch (CC) {
+ default: break;
+ case ISD::SETOLT:
+ case ISD::SETULT:
+ case ISD::SETOLE:
+ case ISD::SETULE:
+ case ISD::SETLT:
+ case ISD::SETLE:
+ std::swap(TrueOpnd, FalseOpnd);
+ LLVM_FALLTHROUGH;
+ case ISD::SETOGT:
+ case ISD::SETUGT:
+ case ISD::SETOGE:
+ case ISD::SETUGE:
+ case ISD::SETGT:
+ case ISD::SETGE:
+ if (TrueOpnd->isExactlyValue(-1.0) && FalseOpnd->isExactlyValue(1.0) &&
+ TLI.isOperationLegal(ISD::FNEG, VT))
+ return DAG.getNode(ISD::FNEG, DL, VT,
+ DAG.getNode(ISD::FABS, DL, VT, X));
+ if (TrueOpnd->isExactlyValue(1.0) && FalseOpnd->isExactlyValue(-1.0))
+ return DAG.getNode(ISD::FABS, DL, VT, X);
+
+ break;
+ }
+ }
+ }
+
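// ---- Editorial sketch, not part of the patch ----
// The select fold above recognizes a sign idiom: with nnan and nsz,
// x * (x > 0.0 ? 1.0 : -1.0) is fabs(x). Scalar analogue:
#include <cmath>
static double selectIdiom(double x) { return x * (x > 0.0 ? 1.0 : -1.0); }
static double foldedIdiom(double x) { return std::fabs(x); }
// ---- End sketch ----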
// FMUL -> FMA combines:
- if (SDValue Fused = visitFMULForFMACombine(N)) {
+ if (SDValue Fused = visitFMULForFMADistributiveCombine(N)) {
AddToWorklist(Fused.getNode());
return Fused;
}
@@ -8484,17 +11437,21 @@ SDValue DAGCombiner::visitFMA(SDNode *N) {
ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
EVT VT = N->getValueType(0);
- SDLoc dl(N);
+ SDLoc DL(N);
const TargetOptions &Options = DAG.getTarget().Options;
+ // FMA nodes have flags that propagate to the created nodes.
+ const SDNodeFlags Flags = N->getFlags();
+ bool UnsafeFPMath = Options.UnsafeFPMath || isContractable(N);
+
// Constant fold FMA.
if (isa<ConstantFPSDNode>(N0) &&
isa<ConstantFPSDNode>(N1) &&
isa<ConstantFPSDNode>(N2)) {
- return DAG.getNode(ISD::FMA, dl, VT, N0, N1, N2);
+ return DAG.getNode(ISD::FMA, DL, VT, N0, N1, N2);
}
- if (Options.UnsafeFPMath) {
+ if (UnsafeFPMath) {
if (N0CFP && N0CFP->isZero())
return N2;
if (N1CFP && N1CFP->isZero())
@@ -8511,29 +11468,24 @@ SDValue DAGCombiner::visitFMA(SDNode *N) {
!isConstantFPBuildVectorOrConstantFP(N1))
return DAG.getNode(ISD::FMA, SDLoc(N), VT, N1, N0, N2);
- // TODO: FMA nodes should have flags that propagate to the created nodes.
- // For now, create a Flags object for use with all unsafe math transforms.
- SDNodeFlags Flags;
- Flags.setUnsafeAlgebra(true);
-
- if (Options.UnsafeFPMath) {
+ if (UnsafeFPMath) {
// (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
if (N2.getOpcode() == ISD::FMUL && N0 == N2.getOperand(0) &&
isConstantFPBuildVectorOrConstantFP(N1) &&
isConstantFPBuildVectorOrConstantFP(N2.getOperand(1))) {
- return DAG.getNode(ISD::FMUL, dl, VT, N0,
- DAG.getNode(ISD::FADD, dl, VT, N1, N2.getOperand(1),
- &Flags), &Flags);
+ return DAG.getNode(ISD::FMUL, DL, VT, N0,
+ DAG.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(1),
+ Flags), Flags);
}
// (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
if (N0.getOpcode() == ISD::FMUL &&
isConstantFPBuildVectorOrConstantFP(N1) &&
isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
- return DAG.getNode(ISD::FMA, dl, VT,
+ return DAG.getNode(ISD::FMA, DL, VT,
N0.getOperand(0),
- DAG.getNode(ISD::FMUL, dl, VT, N1, N0.getOperand(1),
- &Flags),
+ DAG.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(1),
+ Flags),
N2);
}
}
@@ -8543,32 +11495,40 @@ SDValue DAGCombiner::visitFMA(SDNode *N) {
if (N1CFP) {
if (N1CFP->isExactlyValue(1.0))
// TODO: The FMA node should have flags that propagate to this node.
- return DAG.getNode(ISD::FADD, dl, VT, N0, N2);
+ return DAG.getNode(ISD::FADD, DL, VT, N0, N2);
if (N1CFP->isExactlyValue(-1.0) &&
(!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))) {
- SDValue RHSNeg = DAG.getNode(ISD::FNEG, dl, VT, N0);
+ SDValue RHSNeg = DAG.getNode(ISD::FNEG, DL, VT, N0);
AddToWorklist(RHSNeg.getNode());
// TODO: The FMA node should have flags that propagate to this node.
- return DAG.getNode(ISD::FADD, dl, VT, N2, RHSNeg);
+ return DAG.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
+ }
+
+ // fma (fneg x), K, y -> fma x, -K, y
+ if (N0.getOpcode() == ISD::FNEG &&
+ (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
+ (N1.hasOneUse() && !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT)))) {
+ return DAG.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
+ DAG.getNode(ISD::FNEG, DL, VT, N1, Flags), N2);
}
}
- if (Options.UnsafeFPMath) {
+ if (UnsafeFPMath) {
// (fma x, c, x) -> (fmul x, (c+1))
if (N1CFP && N0 == N2) {
- return DAG.getNode(ISD::FMUL, dl, VT, N0,
- DAG.getNode(ISD::FADD, dl, VT,
- N1, DAG.getConstantFP(1.0, dl, VT),
- &Flags), &Flags);
+ return DAG.getNode(ISD::FMUL, DL, VT, N0,
+ DAG.getNode(ISD::FADD, DL, VT, N1,
+ DAG.getConstantFP(1.0, DL, VT), Flags),
+ Flags);
}
// (fma x, c, (fneg x)) -> (fmul x, (c-1))
if (N1CFP && N2.getOpcode() == ISD::FNEG && N2.getOperand(0) == N0) {
- return DAG.getNode(ISD::FMUL, dl, VT, N0,
- DAG.getNode(ISD::FADD, dl, VT,
- N1, DAG.getConstantFP(-1.0, dl, VT),
- &Flags), &Flags);
+ return DAG.getNode(ISD::FMUL, DL, VT, N0,
+ DAG.getNode(ISD::FADD, DL, VT, N1,
+ DAG.getConstantFP(-1.0, DL, VT), Flags),
+ Flags);
}
}
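// ---- Editorial sketch, not part of the patch ----
// The unsafe FMA folds above move a constant through the accumulator:
// fma(x, c, x) == x*(c+1) and fma(x, c, -x) == x*(c-1). Scalar analogues:
static double fmaPlusSelf(double x, double c)  { return x * (c + 1.0); } // (fma x, c, x)
static double fmaMinusSelf(double x, double c) { return x * (c - 1.0); } // (fma x, c, (fneg x))
// ---- End sketch ----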
@@ -8578,14 +11538,14 @@ SDValue DAGCombiner::visitFMA(SDNode *N) {
// Combine multiple FDIVs with the same divisor into multiple FMULs by the
// reciprocal.
// E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip)
-// Notice that this is not always beneficial. One reason is different target
+// Notice that this is not always beneficial. One reason is different targets
// may have different costs for FDIV and FMUL, so sometimes the cost of two
// FDIVs may be lower than the cost of one FDIV and two FMULs. Another reason
// is the critical path is increased from "one FDIV" to "one FDIV + one FMUL".
SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) {
bool UnsafeMath = DAG.getTarget().Options.UnsafeFPMath;
- const SDNodeFlags *Flags = N->getFlags();
- if (!UnsafeMath && !Flags->hasAllowReciprocal())
+ const SDNodeFlags Flags = N->getFlags();
+ if (!UnsafeMath && !Flags.hasAllowReciprocal())
return SDValue();
// Skip if current node is a reciprocal.
@@ -8608,7 +11568,7 @@ SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) {
if (U->getOpcode() == ISD::FDIV && U->getOperand(1) == N1) {
// This division is eligible for optimization only if global unsafe math
// is enabled or if this division allows reciprocal formation.
- if (UnsafeMath || U->getFlags()->hasAllowReciprocal())
+ if (UnsafeMath || U->getFlags().hasAllowReciprocal())
Users.insert(U);
}
}
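// ---- Editorial sketch, not part of the patch ----
// combineRepeatedFPDivisors replaces N divisions by a shared divisor with
// one division and N multiplications. Scalar analogue (hypothetical helper):
static void divideBoth(double a, double b, double d, double &qa, double &qb) {
  double recip = 1.0 / d; // the single shared FDIV
  qa = a * recip;         // formerly a / d
  qb = b * recip;         // formerly b / d
}
// ---- End sketch ----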
@@ -8647,7 +11607,7 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
EVT VT = N->getValueType(0);
SDLoc DL(N);
const TargetOptions &Options = DAG.getTarget().Options;
- SDNodeFlags *Flags = &cast<BinaryWithFlagsSDNode>(N)->Flags;
+ SDNodeFlags Flags = N->getFlags();
// fold vector ops
if (VT.isVector())
@@ -8658,11 +11618,14 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
if (N0CFP && N1CFP)
return DAG.getNode(ISD::FDIV, SDLoc(N), VT, N0, N1, Flags);
- if (Options.UnsafeFPMath) {
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
+ if (Options.UnsafeFPMath || Flags.hasAllowReciprocal()) {
// fold (fdiv X, c2) -> fmul X, 1/c2 if losing precision is acceptable.
if (N1CFP) {
// Compute the reciprocal 1.0 / c2.
- APFloat N1APF = N1CFP->getValueAPF();
+ const APFloat &N1APF = N1CFP->getValueAPF();
APFloat Recip(N1APF.getSemantics(), 1); // 1.0
APFloat::opStatus st = Recip.divide(N1APF, APFloat::rmNearestTiesToEven);
// Only do the transform if the reciprocal is a legal fp immediate that
@@ -8671,8 +11634,8 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
(!LegalOperations ||
// FIXME: custom lowering of ConstantFP might fail (see e.g. ARM
// backend)... we should handle this gracefully after Legalize.
- // TLI.isOperationLegalOrCustom(llvm::ISD::ConstantFP, VT) ||
- TLI.isOperationLegal(llvm::ISD::ConstantFP, VT) ||
+ // TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT) ||
+ TLI.isOperationLegal(ISD::ConstantFP, VT) ||
TLI.isFPImmLegal(Recip, VT)))
return DAG.getNode(ISD::FMUL, DL, VT, N0,
DAG.getConstantFP(Recip, DL, VT), Flags);
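// ---- Editorial sketch, not part of the patch ----
// The fold above fires only when 1/c2 is a usable constant; e.g. 0.25 is
// exact in binary floating point, so x / 4.0 becomes a multiply:
static double divByFour(double x) { return x * 0.25; } // x / 4.0
// ---- End sketch ----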
@@ -8681,12 +11644,12 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
// If this FDIV is part of a reciprocal square root, it may be folded
// into a target-specific square root estimate instruction.
if (N1.getOpcode() == ISD::FSQRT) {
- if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0), Flags)) {
+ if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0), Flags)) {
return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags);
}
} else if (N1.getOpcode() == ISD::FP_EXTEND &&
N1.getOperand(0).getOpcode() == ISD::FSQRT) {
- if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0).getOperand(0),
+ if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0).getOperand(0),
Flags)) {
RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N1), VT, RV);
AddToWorklist(RV.getNode());
@@ -8694,7 +11657,7 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
}
} else if (N1.getOpcode() == ISD::FP_ROUND &&
N1.getOperand(0).getOpcode() == ISD::FSQRT) {
- if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0).getOperand(0),
+ if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0).getOperand(0),
Flags)) {
RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N1), VT, RV, N1.getOperand(1));
AddToWorklist(RV.getNode());
@@ -8715,7 +11678,7 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
if (SqrtOp.getNode()) {
// We found a FSQRT, so try to make this fold:
// x / (y * sqrt(z)) -> x * (rsqrt(z) / y)
- if (SDValue RV = BuildRsqrtEstimate(SqrtOp.getOperand(0), Flags)) {
+ if (SDValue RV = buildRsqrtEstimate(SqrtOp.getOperand(0), Flags)) {
RV = DAG.getNode(ISD::FDIV, SDLoc(N1), VT, RV, OtherOp, Flags);
AddToWorklist(RV.getNode());
return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags);
@@ -8758,41 +11721,26 @@ SDValue DAGCombiner::visitFREM(SDNode *N) {
// fold (frem c1, c2) -> fmod(c1,c2)
if (N0CFP && N1CFP)
- return DAG.getNode(ISD::FREM, SDLoc(N), VT, N0, N1,
- &cast<BinaryWithFlagsSDNode>(N)->Flags);
+ return DAG.getNode(ISD::FREM, SDLoc(N), VT, N0, N1, N->getFlags());
+
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
return SDValue();
}
SDValue DAGCombiner::visitFSQRT(SDNode *N) {
- if (!DAG.getTarget().Options.UnsafeFPMath || TLI.isFsqrtCheap())
+ SDNodeFlags Flags = N->getFlags();
+ if (!DAG.getTarget().Options.UnsafeFPMath &&
+ !Flags.hasApproximateFuncs())
return SDValue();
- // TODO: FSQRT nodes should have flags that propagate to the created nodes.
- // For now, create a Flags object for use with all unsafe math transforms.
- SDNodeFlags Flags;
- Flags.setUnsafeAlgebra(true);
-
- // Compute this as X * (1/sqrt(X)) = X * (X ** -0.5)
- SDValue RV = BuildRsqrtEstimate(N->getOperand(0), &Flags);
- if (!RV)
+ SDValue N0 = N->getOperand(0);
+ if (TLI.isFsqrtCheap(N0, DAG))
return SDValue();
- EVT VT = RV.getValueType();
- SDLoc DL(N);
- RV = DAG.getNode(ISD::FMUL, DL, VT, N->getOperand(0), RV, &Flags);
- AddToWorklist(RV.getNode());
-
- // Unfortunately, RV is now NaN if the input was exactly 0.
- // Select out this case and force the answer to 0.
- SDValue Zero = DAG.getConstantFP(0.0, DL, VT);
- EVT CCVT = getSetCCResultType(VT);
- SDValue ZeroCmp = DAG.getSetCC(DL, CCVT, N->getOperand(0), Zero, ISD::SETEQ);
- AddToWorklist(ZeroCmp.getNode());
- AddToWorklist(RV.getNode());
-
- return DAG.getNode(VT.isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT,
- ZeroCmp, Zero, RV);
+ // FSQRT nodes have flags that propagate to the created nodes.
+ return buildSqrtEstimate(N0, Flags);
}
/// copysign(x, fp_extend(y)) -> copysign(x, y)
@@ -8806,7 +11754,7 @@ static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode *N) {
// value in one SSE register, but instruction selection cannot handle
// FCOPYSIGN on SSE registers yet.
EVT N1VT = N1->getValueType(0);
- EVT N1Op0VT = N1->getOperand(0)->getValueType(0);
+ EVT N1Op0VT = N1->getOperand(0).getValueType();
return (N1VT == N1Op0VT || N1Op0VT != MVT::f128);
}
return false;
@@ -8815,15 +11763,15 @@ static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode *N) {
SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
- ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
- ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
+ bool N0CFP = isConstantFPBuildVectorOrConstantFP(N0);
+ bool N1CFP = isConstantFPBuildVectorOrConstantFP(N1);
EVT VT = N->getValueType(0);
- if (N0CFP && N1CFP) // Constant fold
+ if (N0CFP && N1CFP) // Constant fold
return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1);
- if (N1CFP) {
- const APFloat& V = N1CFP->getValueAPF();
+ if (ConstantFPSDNode *N1C = isConstOrConstSplatFP(N->getOperand(1))) {
+ const APFloat &V = N1C->getValueAPF();
// copysign(x, c1) -> fabs(x) iff ispos(c1)
// copysign(x, c1) -> fneg(fabs(x)) iff isneg(c1)
if (!V.isNegative()) {
@@ -8841,8 +11789,7 @@ SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) {
// copysign(copysign(x,z), y) -> copysign(x, y)
if (N0.getOpcode() == ISD::FABS || N0.getOpcode() == ISD::FNEG ||
N0.getOpcode() == ISD::FCOPYSIGN)
- return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT,
- N0.getOperand(0), N1);
+ return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0.getOperand(0), N1);
// copysign(x, abs(y)) -> abs(x)
if (N1.getOpcode() == ISD::FABS)
@@ -8850,14 +11797,113 @@ SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) {
// copysign(x, copysign(y,z)) -> copysign(x, z)
if (N1.getOpcode() == ISD::FCOPYSIGN)
- return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT,
- N0, N1.getOperand(1));
+ return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1.getOperand(1));
// copysign(x, fp_extend(y)) -> copysign(x, y)
// copysign(x, fp_round(y)) -> copysign(x, y)
if (CanCombineFCOPYSIGN_EXTEND_ROUND(N))
- return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT,
- N0, N1.getOperand(0));
+ return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1.getOperand(0));
+
+ return SDValue();
+}
+
+SDValue DAGCombiner::visitFPOW(SDNode *N) {
+ ConstantFPSDNode *ExponentC = isConstOrConstSplatFP(N->getOperand(1));
+ if (!ExponentC)
+ return SDValue();
+
+ // Try to convert x ** (1/3) into cube root.
+ // TODO: Handle the various flavors of long double.
+ // TODO: Since we're approximating, we don't need an exact 1/3 exponent.
+ // Some range near 1/3 should be fine.
+ EVT VT = N->getValueType(0);
+ if ((VT == MVT::f32 && ExponentC->getValueAPF().isExactlyValue(1.0f/3.0f)) ||
+ (VT == MVT::f64 && ExponentC->getValueAPF().isExactlyValue(1.0/3.0))) {
+ // pow(-0.0, 1/3) = +0.0; cbrt(-0.0) = -0.0.
+ // pow(-inf, 1/3) = +inf; cbrt(-inf) = -inf.
+ // pow(-val, 1/3) = nan; cbrt(-val) = -cbrt(val).
+ // For regular numbers, rounding may cause the results to differ.
+ // Therefore, we require { nsz ninf nnan afn } for this transform.
+ // TODO: We could select out the special cases if we don't have nsz/ninf.
+ SDNodeFlags Flags = N->getFlags();
+ if (!Flags.hasNoSignedZeros() || !Flags.hasNoInfs() || !Flags.hasNoNaNs() ||
+ !Flags.hasApproximateFuncs())
+ return SDValue();
+
+ // Do not create a cbrt() libcall if the target does not have it, and do not
+ // turn a pow that has lowering support into a cbrt() libcall.
+ if (!DAG.getLibInfo().has(LibFunc_cbrt) ||
+ (!DAG.getTargetLoweringInfo().isOperationExpand(ISD::FPOW, VT) &&
+ DAG.getTargetLoweringInfo().isOperationExpand(ISD::FCBRT, VT)))
+ return SDValue();
+
+ return DAG.getNode(ISD::FCBRT, SDLoc(N), VT, N->getOperand(0), Flags);
+ }
+
+ // Try to convert x ** (1/4) into square roots.
+ // x ** (1/2) is canonicalized to sqrt, so we do not bother with that case.
+ // TODO: This could be extended (using a target hook) to handle smaller
+ // power-of-2 fractional exponents.
+ if (ExponentC->getValueAPF().isExactlyValue(0.25)) {
+ // pow(-0.0, 0.25) = +0.0; sqrt(sqrt(-0.0)) = -0.0.
+ // pow(-inf, 0.25) = +inf; sqrt(sqrt(-inf)) = NaN.
+ // For regular numbers, rounding may cause the results to differ.
+ // Therefore, we require { nsz ninf afn } for this transform.
+ // TODO: We could select out the special cases if we don't have nsz/ninf.
+ SDNodeFlags Flags = N->getFlags();
+ if (!Flags.hasNoSignedZeros() || !Flags.hasNoInfs() ||
+ !Flags.hasApproximateFuncs())
+ return SDValue();
+
+ // Don't double the number of libcalls. We are trying to inline fast code.
+ if (!DAG.getTargetLoweringInfo().isOperationLegalOrCustom(ISD::FSQRT, VT))
+ return SDValue();
+
+ // Assume that libcalls are the smallest code.
+ // TODO: This restriction should probably be lifted for vectors.
+ if (DAG.getMachineFunction().getFunction().optForSize())
+ return SDValue();
+
+ // pow(X, 0.25) --> sqrt(sqrt(X))
+ SDLoc DL(N);
+ SDValue Sqrt = DAG.getNode(ISD::FSQRT, DL, VT, N->getOperand(0), Flags);
+ return DAG.getNode(ISD::FSQRT, DL, VT, Sqrt, Flags);
+ }
+
+ return SDValue();
+}
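// ---- Editorial sketch, not part of the patch ----
// visitFPOW above decomposes small fractional exponents; with nsz, ninf,
// and afn, pow(x, 0.25) becomes two square roots. Scalar analogue:
#include <cmath>
static double powQuarter(double x) { return std::sqrt(std::sqrt(x)); }
// ---- End sketch ----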
+
+static SDValue foldFPToIntToFP(SDNode *N, SelectionDAG &DAG,
+ const TargetLowering &TLI) {
+ // This optimization is guarded by a function attribute because it may produce
+ // unexpected results. I.e., programs may be relying on the platform-specific
+ // undefined behavior when the float-to-int conversion overflows.
+ const Function &F = DAG.getMachineFunction().getFunction();
+ Attribute StrictOverflow = F.getFnAttribute("strict-float-cast-overflow");
+ if (StrictOverflow.getValueAsString().equals("false"))
+ return SDValue();
+
+ // We only do this if the target has legal ftrunc. Otherwise, we'd likely be
+ // replacing casts with a libcall. We also must be allowed to ignore -0.0
+ // because FTRUNC will return -0.0 for (-1.0, -0.0), but using integer
+ // conversions would return +0.0.
+ // FIXME: We should be able to use node-level FMF here.
+ // TODO: If strict math, should we use FABS (+ range check for signed cast)?
+ EVT VT = N->getValueType(0);
+ if (!TLI.isOperationLegal(ISD::FTRUNC, VT) ||
+ !DAG.getTarget().Options.NoSignedZerosFPMath)
+ return SDValue();
+
+ // fptosi/fptoui round towards zero, so converting from FP to integer and
+ // back is the same as an 'ftrunc': [us]itofp (fpto[us]i X) --> ftrunc X
+ SDValue N0 = N->getOperand(0);
+ if (N->getOpcode() == ISD::SINT_TO_FP && N0.getOpcode() == ISD::FP_TO_SINT &&
+ N0.getOperand(0).getValueType() == VT)
+ return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0.getOperand(0));
+
+ if (N->getOpcode() == ISD::UINT_TO_FP && N0.getOpcode() == ISD::FP_TO_UINT &&
+ N0.getOperand(0).getValueType() == VT)
+ return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0.getOperand(0));
return SDValue();
}
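// ---- Editorial sketch, not part of the patch ----
// foldFPToIntToFP relies on fptosi/fptoui rounding toward zero, so the
// round trip equals ftrunc for in-range inputs. Scalar analogue:
#include <cmath>
static double viaInt(double x)   { return (double)(long long)x; } // UB on overflow
static double viaTrunc(double x) { return std::trunc(x); }        // defined everywhere
// ---- End sketch ----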
@@ -8868,16 +11914,16 @@ SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) {
EVT OpVT = N0.getValueType();
// fold (sint_to_fp c1) -> c1fp
- if (isConstantIntBuildVectorOrConstantInt(N0) &&
+ if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
// ...but only if the target supports immediate floating-point values
(!LegalOperations ||
- TLI.isOperationLegalOrCustom(llvm::ISD::ConstantFP, VT)))
+ TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0);
// If the input is a legal type, and SINT_TO_FP is not legal on this target,
// but UINT_TO_FP is legal on this target, try to convert.
- if (!TLI.isOperationLegalOrCustom(ISD::SINT_TO_FP, OpVT) &&
- TLI.isOperationLegalOrCustom(ISD::UINT_TO_FP, OpVT)) {
+ if (!hasOperation(ISD::SINT_TO_FP, OpVT) &&
+ hasOperation(ISD::UINT_TO_FP, OpVT)) {
// If the sign bit is known to be zero, we can change this to UINT_TO_FP.
if (DAG.SignBitIsZero(N0))
return DAG.getNode(ISD::UINT_TO_FP, SDLoc(N), VT, N0);
@@ -8889,7 +11935,7 @@ SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) {
if (N0.getOpcode() == ISD::SETCC && N0.getValueType() == MVT::i1 &&
!VT.isVector() &&
(!LegalOperations ||
- TLI.isOperationLegalOrCustom(llvm::ISD::ConstantFP, VT))) {
+ TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
SDLoc DL(N);
SDValue Ops[] =
{ N0.getOperand(0), N0.getOperand(1),
@@ -8903,7 +11949,7 @@ SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) {
if (N0.getOpcode() == ISD::ZERO_EXTEND &&
N0.getOperand(0).getOpcode() == ISD::SETCC &&!VT.isVector() &&
(!LegalOperations ||
- TLI.isOperationLegalOrCustom(llvm::ISD::ConstantFP, VT))) {
+ TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
SDLoc DL(N);
SDValue Ops[] =
{ N0.getOperand(0).getOperand(0), N0.getOperand(0).getOperand(1),
@@ -8913,6 +11959,9 @@ SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) {
}
}
+ if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
+ return FTrunc;
+
return SDValue();
}
@@ -8922,16 +11971,16 @@ SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) {
EVT OpVT = N0.getValueType();
// fold (uint_to_fp c1) -> c1fp
- if (isConstantIntBuildVectorOrConstantInt(N0) &&
+ if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
// ...but only if the target supports immediate floating-point values
(!LegalOperations ||
- TLI.isOperationLegalOrCustom(llvm::ISD::ConstantFP, VT)))
+ TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
return DAG.getNode(ISD::UINT_TO_FP, SDLoc(N), VT, N0);
// If the input is a legal type, and UINT_TO_FP is not legal on this target,
// but SINT_TO_FP is legal on this target, try to convert.
- if (!TLI.isOperationLegalOrCustom(ISD::UINT_TO_FP, OpVT) &&
- TLI.isOperationLegalOrCustom(ISD::SINT_TO_FP, OpVT)) {
+ if (!hasOperation(ISD::UINT_TO_FP, OpVT) &&
+ hasOperation(ISD::SINT_TO_FP, OpVT)) {
// If the sign bit is known to be zero, we can change this to SINT_TO_FP.
if (DAG.SignBitIsZero(N0))
return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0);
@@ -8940,10 +11989,9 @@ SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) {
// The next optimizations are desirable only if SELECT_CC can be lowered.
if (TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT) || !LegalOperations) {
    // fold (uint_to_fp (setcc x, y, cc)) -> (select_cc x, y, -1.0, 0.0, cc)
-
if (N0.getOpcode() == ISD::SETCC && !VT.isVector() &&
(!LegalOperations ||
- TLI.isOperationLegalOrCustom(llvm::ISD::ConstantFP, VT))) {
+ TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
SDLoc DL(N);
SDValue Ops[] =
{ N0.getOperand(0), N0.getOperand(1),
@@ -8953,6 +12001,9 @@ SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) {
}
}
+ if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
+ return FTrunc;
+
return SDValue();
}
@@ -8993,9 +12044,7 @@ static SDValue FoldIntToFPToInt(SDNode *N, SelectionDAG &DAG) {
}
if (VT.getScalarSizeInBits() < SrcVT.getScalarSizeInBits())
return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Src);
- if (SrcVT == VT)
- return Src;
- return DAG.getNode(ISD::BITCAST, SDLoc(N), VT, Src);
+ return DAG.getBitcast(VT, Src);
}
return SDValue();
}
@@ -9039,7 +12088,18 @@ SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
// fold (fp_round (fp_round x)) -> (fp_round x)
if (N0.getOpcode() == ISD::FP_ROUND) {
const bool NIsTrunc = N->getConstantOperandVal(1) == 1;
- const bool N0IsTrunc = N0.getNode()->getConstantOperandVal(1) == 1;
+ const bool N0IsTrunc = N0.getConstantOperandVal(1) == 1;
+
+ // Skip this folding if it results in an fp_round from f80 to f16.
+ //
+ // f80 to f16 always generates an expensive (and, as yet, unimplemented)
+ // libcall to __truncxfhf2 instead of selecting native f16 conversion
+ // instructions from f32 or f64. Moreover, the first (value-preserving)
+ // fp_round from f80 to either f32 or f64 may become a NOP on platforms like
+ // x86.
+ if (N0.getOperand(0).getValueType() == MVT::f80 && VT == MVT::f16)
+ return SDValue();
+
// If the first fp_round isn't a value preserving truncation, it might
// introduce a tie in the second fp_round, that wouldn't occur in the
// single-step fp_round we want to fold to.
@@ -9061,6 +12121,9 @@ SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
Tmp, N0.getOperand(1));
}
+ if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
+ return NewVSel;
+
return SDValue();
}
@@ -9101,7 +12164,7 @@ SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
// Turn fp_extend(fp_round(X, 1)) -> x since the fp_round doesn't affect the
// value of X.
if (N0.getOpcode() == ISD::FP_ROUND
- && N0.getNode()->getConstantOperandVal(1) == 1) {
+ && N0.getConstantOperandVal(1) == 1) {
SDValue In = N0.getOperand(0);
if (In.getValueType() == VT) return In;
if (VT.bitsLT(In.getValueType()))
@@ -9127,6 +12190,9 @@ SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
+ if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
+ return NewVSel;
+
return SDValue();
}
@@ -9149,6 +12215,19 @@ SDValue DAGCombiner::visitFTRUNC(SDNode *N) {
if (isConstantFPBuildVectorOrConstantFP(N0))
return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0);
+ // fold ftrunc (known rounded int x) -> x
+ // ftrunc is part of the fptosi/fptoui expansion on some targets, so it is
+ // likely to appear when extracting an integer from an already-rounded
+ // floating-point value.
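+ // For example, (ftrunc (ffloor X)) --> (ffloor X): ffloor already produces
+ // an integral value, so the outer ftrunc is a no-op.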
+ switch (N0.getOpcode()) {
+ default: break;
+ case ISD::FRINT:
+ case ISD::FTRUNC:
+ case ISD::FNEARBYINT:
+ case ISD::FFLOOR:
+ case ISD::FCEIL:
+ return N0;
+ }
+
return SDValue();
}
@@ -9188,17 +12267,17 @@ SDValue DAGCombiner::visitFNEG(SDNode *N) {
if (N0.getValueType().isVector()) {
// For a vector, get a mask such as 0x80... per scalar element
// and splat it.
- SignMask = APInt::getSignBit(N0.getValueType().getScalarSizeInBits());
+ SignMask = APInt::getSignMask(N0.getScalarValueSizeInBits());
SignMask = APInt::getSplat(IntVT.getSizeInBits(), SignMask);
} else {
// For a scalar, just generate 0x80...
- SignMask = APInt::getSignBit(IntVT.getSizeInBits());
+ SignMask = APInt::getSignMask(IntVT.getSizeInBits());
}
SDLoc DL0(N0);
Int = DAG.getNode(ISD::XOR, DL0, IntVT, Int,
DAG.getConstant(SignMask, DL0, IntVT));
AddToWorklist(Int.getNode());
- return DAG.getNode(ISD::BITCAST, SDLoc(N), VT, Int);
+ return DAG.getBitcast(VT, Int);
}
}
@@ -9212,17 +12291,18 @@ SDValue DAGCombiner::visitFNEG(SDNode *N) {
if (Level >= AfterLegalizeDAG &&
(TLI.isFPImmLegal(CVal, VT) ||
TLI.isOperationLegal(ISD::ConstantFP, VT)))
- return DAG.getNode(ISD::FMUL, SDLoc(N), VT, N0.getOperand(0),
- DAG.getNode(ISD::FNEG, SDLoc(N), VT,
- N0.getOperand(1)),
- &cast<BinaryWithFlagsSDNode>(N0)->Flags);
+ return DAG.getNode(
+ ISD::FMUL, SDLoc(N), VT, N0.getOperand(0),
+ DAG.getNode(ISD::FNEG, SDLoc(N), VT, N0.getOperand(1)),
+ N0->getFlags());
}
}
return SDValue();
}
-SDValue DAGCombiner::visitFMINNUM(SDNode *N) {
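+// Shared combine for the FP min/max nodes: constant-fold when both operands
+// are constants, e.g. fminnum(1.0, 2.0) --> 1.0, and canonicalize a lone
+// constant operand to the RHS.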
+static SDValue visitFMinMax(SelectionDAG &DAG, SDNode *N,
+ APFloat (*Op)(const APFloat &, const APFloat &)) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
EVT VT = N->getValueType(0);
@@ -9232,36 +12312,31 @@ SDValue DAGCombiner::visitFMINNUM(SDNode *N) {
if (N0CFP && N1CFP) {
const APFloat &C0 = N0CFP->getValueAPF();
const APFloat &C1 = N1CFP->getValueAPF();
- return DAG.getConstantFP(minnum(C0, C1), SDLoc(N), VT);
+ return DAG.getConstantFP(Op(C0, C1), SDLoc(N), VT);
}
// Canonicalize to constant on RHS.
if (isConstantFPBuildVectorOrConstantFP(N0) &&
- !isConstantFPBuildVectorOrConstantFP(N1))
- return DAG.getNode(ISD::FMINNUM, SDLoc(N), VT, N1, N0);
+ !isConstantFPBuildVectorOrConstantFP(N1))
+ return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0);
return SDValue();
}
-SDValue DAGCombiner::visitFMAXNUM(SDNode *N) {
- SDValue N0 = N->getOperand(0);
- SDValue N1 = N->getOperand(1);
- EVT VT = N->getValueType(0);
- const ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0);
- const ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1);
+SDValue DAGCombiner::visitFMINNUM(SDNode *N) {
+ return visitFMinMax(DAG, N, minnum);
+}
- if (N0CFP && N1CFP) {
- const APFloat &C0 = N0CFP->getValueAPF();
- const APFloat &C1 = N1CFP->getValueAPF();
- return DAG.getConstantFP(maxnum(C0, C1), SDLoc(N), VT);
- }
+SDValue DAGCombiner::visitFMAXNUM(SDNode *N) {
+ return visitFMinMax(DAG, N, maxnum);
+}
- // Canonicalize to constant on RHS.
- if (isConstantFPBuildVectorOrConstantFP(N0) &&
- !isConstantFPBuildVectorOrConstantFP(N1))
- return DAG.getNode(ISD::FMAXNUM, SDLoc(N), VT, N1, N0);
+SDValue DAGCombiner::visitFMINIMUM(SDNode *N) {
+ return visitFMinMax(DAG, N, minimum);
+}
- return SDValue();
+SDValue DAGCombiner::visitFMAXIMUM(SDNode *N) {
+ return visitFMinMax(DAG, N, maximum);
}
SDValue DAGCombiner::visitFABS(SDNode *N) {
@@ -9281,11 +12356,8 @@ SDValue DAGCombiner::visitFABS(SDNode *N) {
if (N0.getOpcode() == ISD::FNEG || N0.getOpcode() == ISD::FCOPYSIGN)
return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0.getOperand(0));
- // Transform fabs(bitconvert(x)) -> bitconvert(x & ~sign) to avoid loading
- // constant pool values.
- if (!TLI.isFAbsFree(VT) &&
- N0.getOpcode() == ISD::BITCAST &&
- N0.getNode()->hasOneUse()) {
+ // fabs(bitcast(x)) -> bitcast(x & ~sign) to avoid constant pool loads.
+ if (!TLI.isFAbsFree(VT) && N0.getOpcode() == ISD::BITCAST && N0.hasOneUse()) {
SDValue Int = N0.getOperand(0);
EVT IntVT = Int.getValueType();
if (IntVT.isInteger() && !IntVT.isVector()) {
@@ -9293,17 +12365,17 @@ SDValue DAGCombiner::visitFABS(SDNode *N) {
if (N0.getValueType().isVector()) {
// For a vector, get a mask such as 0x7f... per scalar element
// and splat it.
- SignMask = ~APInt::getSignBit(N0.getValueType().getScalarSizeInBits());
+ SignMask = ~APInt::getSignMask(N0.getScalarValueSizeInBits());
SignMask = APInt::getSplat(IntVT.getSizeInBits(), SignMask);
} else {
// For a scalar, just generate 0x7f...
- SignMask = ~APInt::getSignBit(IntVT.getSizeInBits());
+ SignMask = ~APInt::getSignMask(IntVT.getSizeInBits());
}
SDLoc DL(N0);
Int = DAG.getNode(ISD::AND, DL, IntVT, Int,
DAG.getConstant(SignMask, DL, IntVT));
AddToWorklist(Int.getNode());
- return DAG.getNode(ISD::BITCAST, SDLoc(N), N->getValueType(0), Int);
+ return DAG.getBitcast(N->getValueType(0), Int);
}
}
@@ -9331,16 +12403,22 @@ SDValue DAGCombiner::visitBRCOND(SDNode *N) {
N1.getOperand(0), N1.getOperand(1), N2);
}
- if ((N1.hasOneUse() && N1.getOpcode() == ISD::SRL) ||
- ((N1.getOpcode() == ISD::TRUNCATE && N1.hasOneUse()) &&
- (N1.getOperand(0).hasOneUse() &&
- N1.getOperand(0).getOpcode() == ISD::SRL))) {
- SDNode *Trunc = nullptr;
- if (N1.getOpcode() == ISD::TRUNCATE) {
- // Look pass the truncate.
- Trunc = N1.getNode();
- N1 = N1.getOperand(0);
- }
+ if (N1.hasOneUse()) {
+ if (SDValue NewN1 = rebuildSetCC(N1))
+ return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other, Chain, NewN1, N2);
+ }
+
+ return SDValue();
+}
+
+SDValue DAGCombiner::rebuildSetCC(SDValue N) {
+ if (N.getOpcode() == ISD::SRL ||
+ (N.getOpcode() == ISD::TRUNCATE &&
+ (N.getOperand(0).hasOneUse() &&
+ N.getOperand(0).getOpcode() == ISD::SRL))) {
+ // Look past the truncate.
+ if (N.getOpcode() == ISD::TRUNCATE)
+ N = N.getOperand(0);
// Match this pattern so that we can generate simpler code:
//
@@ -9359,74 +12437,55 @@ SDValue DAGCombiner::visitBRCOND(SDNode *N) {
// This applies only when the AND constant value has one bit set and the
// SRL constant is equal to the log2 of the AND constant. The back-end is
// smart enough to convert the result into a TEST/JMP sequence.
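+ // For example, with %t = (and %x, 8), the branch condition (srl %t, 3)
+ // becomes (setcc %t, 0, ne), since 8 is a power of two and 3 == log2(8).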
- SDValue Op0 = N1.getOperand(0);
- SDValue Op1 = N1.getOperand(1);
+ SDValue Op0 = N.getOperand(0);
+ SDValue Op1 = N.getOperand(1);
- if (Op0.getOpcode() == ISD::AND &&
- Op1.getOpcode() == ISD::Constant) {
+ if (Op0.getOpcode() == ISD::AND && Op1.getOpcode() == ISD::Constant) {
SDValue AndOp1 = Op0.getOperand(1);
if (AndOp1.getOpcode() == ISD::Constant) {
const APInt &AndConst = cast<ConstantSDNode>(AndOp1)->getAPIntValue();
if (AndConst.isPowerOf2() &&
- cast<ConstantSDNode>(Op1)->getAPIntValue()==AndConst.logBase2()) {
+ cast<ConstantSDNode>(Op1)->getAPIntValue() == AndConst.logBase2()) {
SDLoc DL(N);
- SDValue SetCC =
- DAG.getSetCC(DL,
- getSetCCResultType(Op0.getValueType()),
- Op0, DAG.getConstant(0, DL, Op0.getValueType()),
- ISD::SETNE);
-
- SDValue NewBRCond = DAG.getNode(ISD::BRCOND, DL,
- MVT::Other, Chain, SetCC, N2);
- // Don't add the new BRCond into the worklist or else SimplifySelectCC
- // will convert it back to (X & C1) >> C2.
- CombineTo(N, NewBRCond, false);
- // Truncate is dead.
- if (Trunc)
- deleteAndRecombine(Trunc);
- // Replace the uses of SRL with SETCC
- WorklistRemover DeadNodes(*this);
- DAG.ReplaceAllUsesOfValueWith(N1, SetCC);
- deleteAndRecombine(N1.getNode());
- return SDValue(N, 0); // Return N so it doesn't get rechecked!
+ return DAG.getSetCC(DL, getSetCCResultType(Op0.getValueType()),
+ Op0, DAG.getConstant(0, DL, Op0.getValueType()),
+ ISD::SETNE);
}
}
}
-
- if (Trunc)
- // Restore N1 if the above transformation doesn't match.
- N1 = N->getOperand(1);
}
// Transform br(xor(x, y)) -> br(x != y)
// Transform br(xor(xor(x,y), 1)) -> br (x == y)
- if (N1.hasOneUse() && N1.getOpcode() == ISD::XOR) {
- SDNode *TheXor = N1.getNode();
+ if (N.getOpcode() == ISD::XOR) {
+ // Because we may call this on a speculatively constructed
+ // SimplifiedSetCC node, we need to simplify this node first.
+ // Ideally this should be folded into SimplifySetCC rather than
+ // here. For now, grab a handle to N so we don't lose it to
+ // replacements internal to the visit.
+ HandleSDNode XORHandle(N);
+ while (N.getOpcode() == ISD::XOR) {
+ SDValue Tmp = visitXOR(N.getNode());
+ // No simplification done.
+ if (!Tmp.getNode())
+ break;
+ // Returning N is a form of in-visit replacement that may have
+ // invalidated N. Grab the value from the handle.
+ if (Tmp.getNode() == N.getNode())
+ N = XORHandle.getValue();
+ else // Node simplified. Try simplifying again.
+ N = Tmp;
+ }
+
+ if (N.getOpcode() != ISD::XOR)
+ return N;
+
+ SDNode *TheXor = N.getNode();
+
SDValue Op0 = TheXor->getOperand(0);
SDValue Op1 = TheXor->getOperand(1);
- if (Op0.getOpcode() == Op1.getOpcode()) {
- // Avoid missing important xor optimizations.
- if (SDValue Tmp = visitXOR(TheXor)) {
- if (Tmp.getNode() != TheXor) {
- DEBUG(dbgs() << "\nReplacing.8 ";
- TheXor->dump(&DAG);
- dbgs() << "\nWith: ";
- Tmp.getNode()->dump(&DAG);
- dbgs() << '\n');
- WorklistRemover DeadNodes(*this);
- DAG.ReplaceAllUsesOfValueWith(N1, Tmp);
- deleteAndRecombine(TheXor);
- return DAG.getNode(ISD::BRCOND, SDLoc(N),
- MVT::Other, Chain, Tmp, N2);
- }
-
- // visitXOR has changed XOR's operands or replaced the XOR completely,
- // bail out.
- return SDValue(N, 0);
- }
- }
if (Op0.getOpcode() != ISD::SETCC && Op1.getOpcode() != ISD::SETCC) {
bool Equal = false;
@@ -9436,19 +12495,12 @@ SDValue DAGCombiner::visitBRCOND(SDNode *N) {
Equal = true;
}
- EVT SetCCVT = N1.getValueType();
+ EVT SetCCVT = N.getValueType();
if (LegalTypes)
SetCCVT = getSetCCResultType(SetCCVT);
- SDValue SetCC = DAG.getSetCC(SDLoc(TheXor),
- SetCCVT,
- Op0, Op1,
- Equal ? ISD::SETEQ : ISD::SETNE);
// Replace the uses of XOR with SETCC
- WorklistRemover DeadNodes(*this);
- DAG.ReplaceAllUsesOfValueWith(N1, SetCC);
- deleteAndRecombine(N1.getNode());
- return DAG.getNode(ISD::BRCOND, SDLoc(N),
- MVT::Other, Chain, SetCC, N2);
+ return DAG.getSetCC(SDLoc(TheXor), SetCCVT, Op0, Op1,
+ Equal ? ISD::SETEQ : ISD::SETNE);
}
}
@@ -9607,6 +12659,11 @@ bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) {
return false;
}
+ // Caches for hasPredecessorHelper.
+ SmallPtrSet<const SDNode *, 32> Visited;
+ SmallVector<const SDNode *, 16> Worklist;
+ Worklist.push_back(N);
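+ // Visited and Worklist persist across all hasPredecessorHelper queries
+ // below, so the predecessor walk from N is shared rather than redone for
+ // each use of the base pointer.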
+
// If the offset is a constant, there may be other adds of constants that
// can be folded with this one. We should do this to avoid having to keep
// a copy of the original base pointer.
@@ -9621,7 +12678,7 @@ bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) {
if (Use.getUser() == Ptr.getNode() || Use != BasePtr)
continue;
- if (Use.getUser()->isPredecessorOf(N))
+ if (SDNode::hasPredecessorHelper(Use.getUser(), Visited, Worklist))
continue;
if (Use.getUser()->getOpcode() != ISD::ADD &&
@@ -9651,14 +12708,10 @@ bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) {
// Now check for #3 and #4.
bool RealUse = false;
- // Caches for hasPredecessorHelper
- SmallPtrSet<const SDNode *, 32> Visited;
- SmallVector<const SDNode *, 16> Worklist;
-
for (SDNode *Use : Ptr.getNode()->uses()) {
if (Use == N)
continue;
- if (N->hasPredecessorHelper(Use, Visited, Worklist))
+ if (SDNode::hasPredecessorHelper(Use, Visited, Worklist))
return false;
// If Ptr may be folded in addressing mode of other use, then it's
@@ -9679,11 +12732,8 @@ bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) {
BasePtr, Offset, AM);
++PreIndexedNodes;
++NodesCombined;
- DEBUG(dbgs() << "\nReplacing.4 ";
- N->dump(&DAG);
- dbgs() << "\nWith: ";
- Result.getNode()->dump(&DAG);
- dbgs() << '\n');
+ LLVM_DEBUG(dbgs() << "\nReplacing.4 "; N->dump(&DAG); dbgs() << "\nWith: ";
+ Result.getNode()->dump(&DAG); dbgs() << '\n');
WorklistRemover DeadNodes(*this);
if (isLoad) {
DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0));
@@ -9712,7 +12762,7 @@ bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) {
// x1 * offset1 + y1 * ptr0 = t1 (the indexed load/store)
//
// where x0, x1, y0 and y1 in {-1, 1} are given by the types of the
- // indexed load/store and the expresion that needs to be re-written.
+ // indexed load/store and the expression that needs to be re-written.
//
// Therefore, we have:
// t0 = (x0 * offset0 - x1 * y0 * y1 *offset1) + (y0 * y1) * t1
@@ -9720,7 +12770,7 @@ bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) {
ConstantSDNode *CN =
cast<ConstantSDNode>(OtherUses[i]->getOperand(OffsetIdx));
int X0, X1, Y0, Y1;
- APInt Offset0 = CN->getAPIntValue();
+ const APInt &Offset0 = CN->getAPIntValue();
APInt Offset1 = cast<ConstantSDNode>(Offset)->getAPIntValue();
X0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 1) ? -1 : 1;
@@ -9751,6 +12801,7 @@ bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) {
// Replace the uses of Ptr with uses of the updated base value.
DAG.ReplaceAllUsesOfValueWith(Ptr, Result.getValue(isLoad ? 1 : 0));
deleteAndRecombine(Ptr.getNode());
+ AddToWorklist(Result.getNode());
return true;
}
@@ -9838,8 +12889,15 @@ bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) {
if (TryNext)
continue;
- // Check for #2
- if (!Op->isPredecessorOf(N) && !N->isPredecessorOf(Op)) {
+ // Check for #2.
+ SmallPtrSet<const SDNode *, 32> Visited;
+ SmallVector<const SDNode *, 8> Worklist;
+ // Ptr is a predecessor of both N and Op.
+ Visited.insert(Ptr.getNode());
+ Worklist.push_back(N);
+ Worklist.push_back(Op);
+ if (!SDNode::hasPredecessorHelper(N, Visited, Worklist) &&
+ !SDNode::hasPredecessorHelper(Op, Visited, Worklist)) {
SDValue Result = isLoad
? DAG.getIndexedLoad(SDValue(N,0), SDLoc(N),
BasePtr, Offset, AM)
@@ -9847,11 +12905,9 @@ bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) {
BasePtr, Offset, AM);
++PostIndexedNodes;
++NodesCombined;
- DEBUG(dbgs() << "\nReplacing.5 ";
- N->dump(&DAG);
- dbgs() << "\nWith: ";
- Result.getNode()->dump(&DAG);
- dbgs() << '\n');
+ LLVM_DEBUG(dbgs() << "\nReplacing.5 "; N->dump(&DAG);
+ dbgs() << "\nWith: "; Result.getNode()->dump(&DAG);
+ dbgs() << '\n');
WorklistRemover DeadNodes(*this);
if (isLoad) {
DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0));
@@ -9875,7 +12931,7 @@ bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) {
return false;
}
-/// \brief Return the base-pointer arithmetic from an indexed \p LD.
+/// Return the base-pointer arithmetic from an indexed \p LD.
SDValue DAGCombiner::SplitIndexingFromLoad(LoadSDNode *LD) {
ISD::MemIndexedMode AM = LD->getAddressingMode();
assert(AM != ISD::UNINDEXED);
@@ -9899,6 +12955,157 @@ SDValue DAGCombiner::SplitIndexingFromLoad(LoadSDNode *LD) {
return DAG.getNode(Opc, SDLoc(LD), BP.getSimpleValueType(), BP, Inc);
}
+static inline int numVectorEltsOrZero(EVT T) {
+ return T.isVector() ? T.getVectorNumElements() : 0;
+}
+
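+// Try to convert the stored value to the store's memory type, via FTRUNC,
+// TRUNCATE, or a same-size bitcast; e.g. an i32 value stored as i16 becomes
+// (i16 (trunc Val)). Returns false if no such conversion applies here.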
+bool DAGCombiner::getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val) {
+ Val = ST->getValue();
+ EVT STType = Val.getValueType();
+ EVT STMemType = ST->getMemoryVT();
+ if (STType == STMemType)
+ return true;
+ if (isTypeLegal(STMemType))
+ return false; // fail.
+ if (STType.isFloatingPoint() && STMemType.isFloatingPoint() &&
+ TLI.isOperationLegal(ISD::FTRUNC, STMemType)) {
+ Val = DAG.getNode(ISD::FTRUNC, SDLoc(ST), STMemType, Val);
+ return true;
+ }
+ if (numVectorEltsOrZero(STType) == numVectorEltsOrZero(STMemType) &&
+ STType.isInteger() && STMemType.isInteger()) {
+ Val = DAG.getNode(ISD::TRUNCATE, SDLoc(ST), STMemType, Val);
+ return true;
+ }
+ if (STType.getSizeInBits() == STMemType.getSizeInBits()) {
+ Val = DAG.getBitcast(STMemType, Val);
+ return true;
+ }
+ return false; // fail.
+}
+
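+// Extend Val, which has the load's memory type, to the load's result type
+// according to the extension kind; e.g. a ZEXTLOAD of i16 into i32 turns Val
+// into (i32 (zext Val)).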
+bool DAGCombiner::extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val) {
+ EVT LDMemType = LD->getMemoryVT();
+ EVT LDType = LD->getValueType(0);
+ assert(Val.getValueType() == LDMemType &&
+ "Attempting to extend value of non-matching type");
+ if (LDType == LDMemType)
+ return true;
+ if (LDMemType.isInteger() && LDType.isInteger()) {
+ switch (LD->getExtensionType()) {
+ case ISD::NON_EXTLOAD:
+ Val = DAG.getBitcast(LDType, Val);
+ return true;
+ case ISD::EXTLOAD:
+ Val = DAG.getNode(ISD::ANY_EXTEND, SDLoc(LD), LDType, Val);
+ return true;
+ case ISD::SEXTLOAD:
+ Val = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(LD), LDType, Val);
+ return true;
+ case ISD::ZEXTLOAD:
+ Val = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(LD), LDType, Val);
+ return true;
+ }
+ }
+ return false;
+}
+
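+// Try to forward the value of a store directly chained before this load when
+// the stored bits cover all loaded bits, e.g.
+//   store i32 %v, i32* %p
+//   %x = load i32, i32* %p   ; chained after the store
+// so %x is replaced by %v and the load goes away.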
+SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) {
+ if (OptLevel == CodeGenOpt::None || LD->isVolatile())
+ return SDValue();
+ SDValue Chain = LD->getOperand(0);
+ StoreSDNode *ST = dyn_cast<StoreSDNode>(Chain.getNode());
+ if (!ST || ST->isVolatile())
+ return SDValue();
+
+ EVT LDType = LD->getValueType(0);
+ EVT LDMemType = LD->getMemoryVT();
+ EVT STMemType = ST->getMemoryVT();
+ EVT STType = ST->getValue().getValueType();
+
+ BaseIndexOffset BasePtrLD = BaseIndexOffset::match(LD, DAG);
+ BaseIndexOffset BasePtrST = BaseIndexOffset::match(ST, DAG);
+ int64_t Offset;
+ if (!BasePtrST.equalBaseIndex(BasePtrLD, DAG, Offset))
+ return SDValue();
+
+ // Normalize for endianness. After this, Offset=0 will denote that the least
+ // significant bit in the loaded value maps to the least significant bit in
+ // the stored value. With Offset=n (for n > 0) the loaded value starts at the
+ // nth least significant byte of the stored value.
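+ // For example, a 4-byte store feeding a 1-byte load of its low byte: on
+ // little-endian targets the load address equals the store base, so Offset is
+ // already 0; on big-endian targets the low byte sits at base+3, so Offset=3
+ // normalizes to (32 - 8) / 8 - 3 == 0.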
+ if (DAG.getDataLayout().isBigEndian())
+ Offset = (STMemType.getStoreSizeInBits() -
+ LDMemType.getStoreSizeInBits()) / 8 - Offset;
+
+ // Check that the stored value covers all bits that are loaded.
+ bool STCoversLD =
+ (Offset >= 0) &&
+ (Offset * 8 + LDMemType.getSizeInBits() <= STMemType.getSizeInBits());
+
+ auto ReplaceLd = [&](LoadSDNode *LD, SDValue Val, SDValue Chain) -> SDValue {
+ if (LD->isIndexed()) {
+ bool IsSub = (LD->getAddressingMode() == ISD::PRE_DEC ||
+ LD->getAddressingMode() == ISD::POST_DEC);
+ unsigned Opc = IsSub ? ISD::SUB : ISD::ADD;
+ SDValue Idx = DAG.getNode(Opc, SDLoc(LD), LD->getOperand(1).getValueType(),
+ LD->getOperand(1), LD->getOperand(2));
+ SDValue Ops[] = {Val, Idx, Chain};
+ return CombineTo(LD, Ops, 3);
+ }
+ return CombineTo(LD, Val, Chain);
+ };
+
+ if (!STCoversLD)
+ return SDValue();
+
+ // Memory as copy space (potentially masked).
+ if (Offset == 0 && LDType == STType && STMemType == LDMemType) {
+ // Simple case: Direct non-truncating forwarding
+ if (LDType.getSizeInBits() == LDMemType.getSizeInBits())
+ return ReplaceLd(LD, ST->getValue(), Chain);
+ // Can we model the truncate and extension with an and mask?
+ if (STType.isInteger() && LDMemType.isInteger() && !STType.isVector() &&
+ !LDMemType.isVector() && LD->getExtensionType() != ISD::SEXTLOAD) {
+ // Mask to size of LDMemType
+ auto Mask =
+ DAG.getConstant(APInt::getLowBitsSet(STType.getSizeInBits(),
+ STMemType.getSizeInBits()),
+ SDLoc(ST), STType);
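+ // For example, an i32 value trunc-stored as i16 and zext-loaded back as
+ // i32 is forwarded as (and StoredVal, 0xFFFF).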
+ auto Val = DAG.getNode(ISD::AND, SDLoc(LD), LDType, ST->getValue(), Mask);
+ return ReplaceLd(LD, Val, Chain);
+ }
+ }
+
+ // TODO: Deal with nonzero offset.
+ if (LD->getBasePtr().isUndef() || Offset != 0)
+ return SDValue();
+ // Model necessary truncations / extensions.
+ SDValue Val;
+ // Truncate the value to the stored memory size.
+ do {
+ if (!getTruncatedStoreValue(ST, Val))
+ continue;
+ if (!isTypeLegal(LDMemType))
+ continue;
+ if (STMemType != LDMemType) {
+ // TODO: Support vectors? This requires extract_subvector/bitcast.
+ if (!STMemType.isVector() && !LDMemType.isVector() &&
+ STMemType.isInteger() && LDMemType.isInteger())
+ Val = DAG.getNode(ISD::TRUNCATE, SDLoc(LD), LDMemType, Val);
+ else
+ continue;
+ }
+ if (!extendLoadedValueToExtension(LD, Val))
+ continue;
+ return ReplaceLd(LD, Val, Chain);
+ } while (false);
+
+ // On failure, clean up dead nodes we may have created.
+ if (Val->use_empty())
+ deleteAndRecombine(Val.getNode());
+ return SDValue();
+}
+
SDValue DAGCombiner::visitLOAD(SDNode *N) {
LoadSDNode *LD = cast<LoadSDNode>(N);
SDValue Chain = LD->getChain();
@@ -9917,14 +13124,12 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) {
// v3 = add v2, c
// Now we replace use of chain2 with chain1. This makes the second load
// isomorphic to the one we are deleting, and thus makes this load live.
- DEBUG(dbgs() << "\nReplacing.6 ";
- N->dump(&DAG);
- dbgs() << "\nWith chain: ";
- Chain.getNode()->dump(&DAG);
- dbgs() << "\n");
+ LLVM_DEBUG(dbgs() << "\nReplacing.6 "; N->dump(&DAG);
+ dbgs() << "\nWith chain: "; Chain.getNode()->dump(&DAG);
+ dbgs() << "\n");
WorklistRemover DeadNodes(*this);
DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
-
+ AddUsersToWorklist(Chain.getNode());
if (N->use_empty())
deleteAndRecombine(N);
@@ -9952,11 +13157,9 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) {
AddUsersToWorklist(N);
} else
Index = DAG.getUNDEF(N->getValueType(1));
- DEBUG(dbgs() << "\nReplacing.7 ";
- N->dump(&DAG);
- dbgs() << "\nWith: ";
- Undef.getNode()->dump(&DAG);
- dbgs() << " and 2 other values\n");
+ LLVM_DEBUG(dbgs() << "\nReplacing.7 "; N->dump(&DAG);
+ dbgs() << "\nWith: "; Undef.getNode()->dump(&DAG);
+ dbgs() << " and 2 other values\n");
WorklistRemover DeadNodes(*this);
DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Undef);
DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Index);
@@ -9969,42 +13172,25 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) {
// If this load is directly stored, replace the load value with the stored
// value.
- // TODO: Handle store large -> read small portion.
- // TODO: Handle TRUNCSTORE/LOADEXT
- if (ISD::isNormalLoad(N) && !LD->isVolatile()) {
- if (ISD::isNON_TRUNCStore(Chain.getNode())) {
- StoreSDNode *PrevST = cast<StoreSDNode>(Chain);
- if (PrevST->getBasePtr() == Ptr &&
- PrevST->getValue().getValueType() == N->getValueType(0))
- return CombineTo(N, Chain.getOperand(1), Chain);
- }
- }
+ if (auto V = ForwardStoreValueToDirectLoad(LD))
+ return V;
// Try to infer better alignment information than the load already has.
if (OptLevel != CodeGenOpt::None && LD->isUnindexed()) {
if (unsigned Align = DAG.InferPtrAlignment(Ptr)) {
- if (Align > LD->getMemOperand()->getBaseAlignment()) {
- SDValue NewLoad =
- DAG.getExtLoad(LD->getExtensionType(), SDLoc(N),
- LD->getValueType(0),
- Chain, Ptr, LD->getPointerInfo(),
- LD->getMemoryVT(),
- LD->isVolatile(), LD->isNonTemporal(),
- LD->isInvariant(), Align, LD->getAAInfo());
- if (NewLoad.getNode() != N)
- return CombineTo(N, NewLoad, SDValue(NewLoad.getNode(), 1), true);
+ if (Align > LD->getAlignment() && LD->getSrcValueOffset() % Align == 0) {
+ SDValue NewLoad = DAG.getExtLoad(
+ LD->getExtensionType(), SDLoc(N), LD->getValueType(0), Chain, Ptr,
+ LD->getPointerInfo(), LD->getMemoryVT(), Align,
+ LD->getMemOperand()->getFlags(), LD->getAAInfo());
+ // NewLoad will always be N, as we are only refining the alignment.
+ assert(NewLoad.getNode() == N);
+ (void)NewLoad;
}
}
}
- bool UseAA = CombinerAA.getNumOccurrences() > 0 ? CombinerAA
- : DAG.getSubtarget().useAA();
-#ifndef NDEBUG
- if (CombinerAAOnlyFunc.getNumOccurrences() &&
- CombinerAAOnlyFunc != DAG.getMachineFunction().getName())
- UseAA = false;
-#endif
- if (UseAA && LD->isUnindexed()) {
+ if (LD->isUnindexed()) {
// Walk up chain skipping non-aliasing memory nodes.
SDValue BetterChain = FindBetterChain(N, Chain);
@@ -10027,12 +13213,8 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) {
SDValue Token = DAG.getNode(ISD::TokenFactor, SDLoc(N),
MVT::Other, Chain, ReplLoad.getValue(1));
- // Make sure the new and old chains are cleaned up.
- AddToWorklist(Token.getNode());
-
- // Replace uses with load result and token factor. Don't add users
- // to work list.
- return CombineTo(N, ReplLoad.getValue(0), Token, false);
+ // Replace uses with load result and token factor
+ return CombineTo(N, ReplLoad.getValue(0), Token);
}
}
@@ -10049,38 +13231,37 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) {
}
namespace {
-/// \brief Helper structure used to slice a load in smaller loads.
+
+/// Helper structure used to slice a load in smaller loads.
/// Basically a slice is obtained from the following sequence:
/// Origin = load Ty1, Base
/// Shift = srl Ty1 Origin, CstTy Amount
/// Inst = trunc Shift to Ty2
///
-/// Then, it will be rewriten into:
+/// Then, it will be rewritten into:
/// Slice = load SliceTy, Base + SliceOffset
/// [Inst = zext Slice to Ty2], only if SliceTy <> Ty2
///
/// SliceTy is deduced from the number of bits that are actually used to
/// build Inst.
struct LoadedSlice {
- /// \brief Helper structure used to compute the cost of a slice.
+ /// Helper structure used to compute the cost of a slice.
struct Cost {
/// Are we optimizing for code size.
bool ForCodeSize;
+
    /// Various costs.
- unsigned Loads;
- unsigned Truncates;
- unsigned CrossRegisterBanksCopies;
- unsigned ZExts;
- unsigned Shift;
+ unsigned Loads = 0;
+ unsigned Truncates = 0;
+ unsigned CrossRegisterBanksCopies = 0;
+ unsigned ZExts = 0;
+ unsigned Shift = 0;
- Cost(bool ForCodeSize = false)
- : ForCodeSize(ForCodeSize), Loads(0), Truncates(0),
- CrossRegisterBanksCopies(0), ZExts(0), Shift(0) {}
+ Cost(bool ForCodeSize = false) : ForCodeSize(ForCodeSize) {}
- /// \brief Get the cost of one isolated slice.
+ /// Get the cost of one isolated slice.
Cost(const LoadedSlice &LS, bool ForCodeSize = false)
- : ForCodeSize(ForCodeSize), Loads(1), Truncates(0),
- CrossRegisterBanksCopies(0), ZExts(0), Shift(0) {
+ : ForCodeSize(ForCodeSize), Loads(1) {
EVT TruncType = LS.Inst->getValueType(0);
EVT LoadedType = LS.getLoadedType();
if (TruncType != LoadedType &&
@@ -10088,7 +13269,7 @@ struct LoadedSlice {
ZExts = 1;
}
- /// \brief Account for slicing gain in the current cost.
+ /// Account for slicing gain in the current cost.
/// Slicing provide a few gains like removing a shift or a
/// truncate. This method allows to grow the cost of the original
/// load with the gain from this slice.
@@ -10142,13 +13323,17 @@ struct LoadedSlice {
bool operator>=(const Cost &RHS) const { return !(*this < RHS); }
};
+
  // The last instruction that represents the slice. This should be a
// truncate instruction.
SDNode *Inst;
+
// The original load instruction.
LoadSDNode *Origin;
+
// The right shift amount in bits from the original load.
unsigned Shift;
+
// The DAG from which Origin came from.
// This is used to get some contextual information about legal types, etc.
SelectionDAG *DAG;
@@ -10157,7 +13342,7 @@ struct LoadedSlice {
unsigned Shift = 0, SelectionDAG *DAG = nullptr)
: Inst(Inst), Origin(Origin), Shift(Shift), DAG(DAG) {}
- /// \brief Get the bits used in a chunk of bits \p BitWidth large.
+ /// Get the bits used in a chunk of bits \p BitWidth large.
/// \return Result is \p BitWidth and has used bits set to 1 and
/// not used bits set to 0.
APInt getUsedBits() const {
@@ -10177,14 +13362,14 @@ struct LoadedSlice {
return UsedBits;
}
- /// \brief Get the size of the slice to be loaded in bytes.
+ /// Get the size of the slice to be loaded in bytes.
unsigned getLoadedSize() const {
unsigned SliceSize = getUsedBits().countPopulation();
assert(!(SliceSize & 0x7) && "Size is not a multiple of a byte.");
return SliceSize / 8;
}
- /// \brief Get the type that will be loaded for this slice.
+ /// Get the type that will be loaded for this slice.
/// Note: This may not be the final type for the slice.
EVT getLoadedType() const {
assert(DAG && "Missing context");
@@ -10192,7 +13377,7 @@ struct LoadedSlice {
return EVT::getIntegerVT(Ctxt, getLoadedSize() * 8);
}
- /// \brief Get the alignment of the load used for this slice.
+ /// Get the alignment of the load used for this slice.
unsigned getAlignment() const {
unsigned Alignment = Origin->getAlignment();
unsigned Offset = getOffsetFromBase();
@@ -10201,14 +13386,14 @@ struct LoadedSlice {
return Alignment;
}
- /// \brief Check if this slice can be rewritten with legal operations.
+ /// Check if this slice can be rewritten with legal operations.
bool isLegal() const {
// An invalid slice is not legal.
if (!Origin || !Inst || !DAG)
return false;
// Offsets are for indexed load only, we do not handle that.
- if (Origin->getOffset().getOpcode() != ISD::UNDEF)
+ if (!Origin->getOffset().isUndef())
return false;
const TargetLowering &TLI = DAG->getTargetLoweringInfo();
@@ -10245,7 +13430,7 @@ struct LoadedSlice {
return true;
}
- /// \brief Get the offset in bytes of this slice in the original chunk of
+ /// Get the offset in bytes of this slice in the original chunk of
/// bits.
/// \pre DAG != nullptr.
uint64_t getOffsetFromBase() const {
@@ -10266,7 +13451,7 @@ struct LoadedSlice {
return Offset;
}
- /// \brief Generate the sequence of instructions to load the slice
+ /// Generate the sequence of instructions to load the slice
/// represented by this object and redirect the uses of this slice to
/// this new sequence of instructions.
/// \pre this->Inst && this->Origin are valid Instructions and this
@@ -10276,7 +13461,7 @@ struct LoadedSlice {
assert(Inst && Origin && "Unable to replace a non-existing slice.");
const SDValue &OldBaseAddr = Origin->getBasePtr();
SDValue BaseAddr = OldBaseAddr;
- // Get the offset in that chunk of bytes w.r.t. the endianess.
+ // Get the offset in that chunk of bytes w.r.t. the endianness.
int64_t Offset = static_cast<int64_t>(getOffsetFromBase());
assert(Offset >= 0 && "Offset too big to fit in int64_t!");
if (Offset) {
@@ -10291,10 +13476,10 @@ struct LoadedSlice {
EVT SliceType = getLoadedType();
// Create the load for the slice.
- SDValue LastInst = DAG->getLoad(
- SliceType, SDLoc(Origin), Origin->getChain(), BaseAddr,
- Origin->getPointerInfo().getWithOffset(Offset), Origin->isVolatile(),
- Origin->isNonTemporal(), Origin->isInvariant(), getAlignment());
+ SDValue LastInst =
+ DAG->getLoad(SliceType, SDLoc(Origin), Origin->getChain(), BaseAddr,
+ Origin->getPointerInfo().getWithOffset(Offset),
+ getAlignment(), Origin->getMemOperand()->getFlags());
// If the final type is not the same as the loaded type, this means that
// we have to pad with zero. Create a zero extend for that.
EVT FinalType = Inst->getValueType(0);
@@ -10304,7 +13489,7 @@ struct LoadedSlice {
return LastInst;
}
- /// \brief Check if this slice can be merged with an expensive cross register
+ /// Check if this slice can be merged with an expensive cross register
/// bank copy. E.g.,
/// i = load i32
/// f = bitcast i32 i to float
@@ -10350,9 +13535,10 @@ struct LoadedSlice {
return true;
}
};
-}
-/// \brief Check that all bits set in \p UsedBits form a dense region, i.e.,
+} // end anonymous namespace
+
+/// Check that all bits set in \p UsedBits form a dense region, i.e.,
/// \p UsedBits looks like 0..0 1..1 0..0.
static bool areUsedBitsDense(const APInt &UsedBits) {
// If all the bits are one, this is dense!
@@ -10368,7 +13554,7 @@ static bool areUsedBitsDense(const APInt &UsedBits) {
return NarrowedUsedBits.isAllOnesValue();
}
-/// \brief Check whether or not \p First and \p Second are next to each other
+/// Check whether or not \p First and \p Second are next to each other
/// in memory. This means that there is no hole between the bits loaded
/// by \p First and the bits loaded by \p Second.
static bool areSlicesNextToEachOther(const LoadedSlice &First,
@@ -10382,7 +13568,7 @@ static bool areSlicesNextToEachOther(const LoadedSlice &First,
return areUsedBitsDense(UsedBits);
}
-/// \brief Adjust the \p GlobalLSCost according to the target
+/// Adjust the \p GlobalLSCost according to the target
 /// pairing capabilities and the layout of the slices.
/// \pre \p GlobalLSCost should account for at least as many loads as
/// there is in the slices in \p LoadedSlices.
@@ -10395,8 +13581,7 @@ static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices,
// Sort the slices so that elements that are likely to be next to each
// other in memory are next to each other in the list.
- std::sort(LoadedSlices.begin(), LoadedSlices.end(),
- [](const LoadedSlice &LHS, const LoadedSlice &RHS) {
+ llvm::sort(LoadedSlices, [](const LoadedSlice &LHS, const LoadedSlice &RHS) {
assert(LHS.Origin == RHS.Origin && "Different bases not implemented.");
return LHS.getOffsetFromBase() < RHS.getOffsetFromBase();
});
@@ -10408,7 +13593,6 @@ static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices,
for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice,
// Set the beginning of the pair.
First = Second) {
-
Second = &LoadedSlices[CurrSlice];
// If First is NULL, it means we start a new pair.
@@ -10444,7 +13628,7 @@ static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices,
}
}
-/// \brief Check the profitability of all involved LoadedSlice.
+/// Check the profitability of all involved LoadedSlice.
/// Currently, it is considered profitable if there is exactly two
/// involved slices (1) which are (2) next to each other in memory, and
/// whose cost (\see LoadedSlice::Cost) is smaller than the original load (3).
@@ -10488,7 +13672,7 @@ static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
return OrigCost > GlobalSlicingCost;
}
-/// \brief If the given load, \p LI, is used only by trunc or trunc(lshr)
+/// If the given load, \p LI, is used only by trunc or trunc(lshr)
/// operations, split it in the various pieces being extracted.
///
/// This sort of thing is introduced by SROA.
@@ -10524,7 +13708,7 @@ bool DAGCombiner::SliceUpLoad(SDNode *N) {
// Check if this is a trunc(lshr).
if (User->getOpcode() == ISD::SRL && User->hasOneUse() &&
isa<ConstantSDNode>(User->getOperand(1))) {
- Shift = cast<ConstantSDNode>(User->getOperand(1))->getZExtValue();
+ Shift = User->getConstantOperandVal(1);
User = *User->use_begin();
}
@@ -10539,7 +13723,7 @@ bool DAGCombiner::SliceUpLoad(SDNode *N) {
// will be across several bytes. We do not support that.
unsigned Width = User->getValueSizeInBits(0);
if (Width < 8 || !isPowerOf2_32(Width) || (Shift & 0x7))
- return 0;
+ return false;
// Build the slice for this chain of computations.
LoadedSlice LS(User, LD, Shift, &DAG);
@@ -10576,7 +13760,7 @@ bool DAGCombiner::SliceUpLoad(SDNode *N) {
LSIt != LSItEnd; ++LSIt) {
SDValue SliceInst = LSIt->loadSlice();
CombineTo(LSIt->Inst, SliceInst, true);
- if (SliceInst.getNode()->getOpcode() != ISD::LOAD)
+ if (SliceInst.getOpcode() != ISD::LOAD)
SliceInst = SliceInst.getOperand(0);
assert(SliceInst->getOpcode() == ISD::LOAD &&
"It takes more than a zext to get to the loaded slice!!");
@@ -10586,6 +13770,7 @@ bool DAGCombiner::SliceUpLoad(SDNode *N) {
SDValue Chain = DAG.getNode(ISD::TokenFactor, SDLoc(LD), MVT::Other,
ArgChains);
DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
+ AddToWorklist(Chain.getNode());
return true;
}
@@ -10606,22 +13791,6 @@ CheckForMaskedLoad(SDValue V, SDValue Ptr, SDValue Chain) {
LoadSDNode *LD = cast<LoadSDNode>(V->getOperand(0));
if (LD->getBasePtr() != Ptr) return Result; // Not from same pointer.
- // The store should be chained directly to the load or be an operand of a
- // tokenfactor.
- if (LD == Chain.getNode())
- ; // ok.
- else if (Chain->getOpcode() != ISD::TokenFactor)
- return Result; // Fail.
- else {
- bool isOk = false;
- for (const SDValue &ChainOp : Chain->op_values())
- if (ChainOp.getNode() == LD) {
- isOk = true;
- break;
- }
- if (!isOk) return Result;
- }
-
// This only handles simple types.
if (V.getValueType() != MVT::i16 &&
V.getValueType() != MVT::i32 &&
@@ -10658,12 +13827,29 @@ CheckForMaskedLoad(SDValue V, SDValue Ptr, SDValue Chain) {
// is aligned the same as the access width.
if (NotMaskTZ && NotMaskTZ/8 % MaskedBytes) return Result;
+ // For narrowing to be valid, the load must be the memory operation
+ // immediately preceding the store.
+ if (LD == Chain.getNode())
+ ; // ok.
+ else if (Chain->getOpcode() == ISD::TokenFactor &&
+ SDValue(LD, 1).hasOneUse()) {
+ // LD has only one chain use, so there are no indirect dependencies.
+ bool isOk = false;
+ for (const SDValue &ChainOp : Chain->op_values())
+ if (ChainOp.getNode() == LD) {
+ isOk = true;
+ break;
+ }
+ if (!isOk)
+ return Result;
+ } else
+ return Result; // Fail.
+
Result.first = MaskedBytes;
Result.second = NotMaskTZ/8;
return Result;
}
-
/// Check to see if IVal is something that provides a value as specified by
/// MaskInfo. If so, replace the specified store with a narrower store of
/// truncated IVal.
@@ -10718,12 +13904,12 @@ ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned, unsigned> &MaskInfo,
IVal = DAG.getNode(ISD::TRUNCATE, SDLoc(IVal), VT, IVal);
++OpsNarrowed;
- return DAG.getStore(St->getChain(), SDLoc(St), IVal, Ptr,
- St->getPointerInfo().getWithOffset(StOffset),
- false, false, NewAlign).getNode();
+ return DAG
+ .getStore(St->getChain(), SDLoc(St), IVal, Ptr,
+ St->getPointerInfo().getWithOffset(StOffset), NewAlign)
+ .getNode();
}
-
/// Look for sequence of load / op / store where op is one of 'or', 'xor', and
/// 'and' of immediates. If 'op' is only touching some of the loaded bits, try
/// narrowing the load and store if it would end up being a win for performance
@@ -10826,19 +14012,16 @@ SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) {
Ptr.getValueType(), Ptr,
DAG.getConstant(PtrOff, SDLoc(LD),
Ptr.getValueType()));
- SDValue NewLD = DAG.getLoad(NewVT, SDLoc(N0),
- LD->getChain(), NewPtr,
- LD->getPointerInfo().getWithOffset(PtrOff),
- LD->isVolatile(), LD->isNonTemporal(),
- LD->isInvariant(), NewAlign,
- LD->getAAInfo());
+ SDValue NewLD =
+ DAG.getLoad(NewVT, SDLoc(N0), LD->getChain(), NewPtr,
+ LD->getPointerInfo().getWithOffset(PtrOff), NewAlign,
+ LD->getMemOperand()->getFlags(), LD->getAAInfo());
SDValue NewVal = DAG.getNode(Opc, SDLoc(Value), NewVT, NewLD,
DAG.getConstant(NewImm, SDLoc(Value),
NewVT));
- SDValue NewST = DAG.getStore(Chain, SDLoc(N),
- NewVal, NewPtr,
- ST->getPointerInfo().getWithOffset(PtrOff),
- false, false, NewAlign);
+ SDValue NewST =
+ DAG.getStore(Chain, SDLoc(N), NewVal, NewPtr,
+ ST->getPointerInfo().getWithOffset(PtrOff), NewAlign);
AddToWorklist(NewPtr.getNode());
AddToWorklist(NewLD.getNode());
@@ -10887,15 +14070,13 @@ SDValue DAGCombiner::TransformFPLoadStorePair(SDNode *N) {
if (LDAlign < ABIAlign || STAlign < ABIAlign)
return SDValue();
- SDValue NewLD = DAG.getLoad(IntVT, SDLoc(Value),
- LD->getChain(), LD->getBasePtr(),
- LD->getPointerInfo(),
- false, false, false, LDAlign);
+ SDValue NewLD =
+ DAG.getLoad(IntVT, SDLoc(Value), LD->getChain(), LD->getBasePtr(),
+ LD->getPointerInfo(), LDAlign);
- SDValue NewST = DAG.getStore(NewLD.getValue(1), SDLoc(N),
- NewLD, ST->getBasePtr(),
- ST->getPointerInfo(),
- false, false, STAlign);
+ SDValue NewST =
+ DAG.getStore(NewLD.getValue(1), SDLoc(N), NewLD, ST->getBasePtr(),
+ ST->getPointerInfo(), STAlign);
AddToWorklist(NewLD.getNode());
AddToWorklist(NewST.getNode());
@@ -10908,96 +14089,6 @@ SDValue DAGCombiner::TransformFPLoadStorePair(SDNode *N) {
return SDValue();
}
-namespace {
-/// Helper struct to parse and store a memory address as base + index + offset.
-/// We ignore sign extensions when it is safe to do so.
-/// The following two expressions are not equivalent. To differentiate we need
-/// to store whether there was a sign extension involved in the index
-/// computation.
-/// (load (i64 add (i64 copyfromreg %c)
-/// (i64 signextend (add (i8 load %index)
-/// (i8 1))))
-/// vs
-///
-/// (load (i64 add (i64 copyfromreg %c)
-/// (i64 signextend (i32 add (i32 signextend (i8 load %index))
-/// (i32 1)))))
-struct BaseIndexOffset {
- SDValue Base;
- SDValue Index;
- int64_t Offset;
- bool IsIndexSignExt;
-
- BaseIndexOffset() : Offset(0), IsIndexSignExt(false) {}
-
- BaseIndexOffset(SDValue Base, SDValue Index, int64_t Offset,
- bool IsIndexSignExt) :
- Base(Base), Index(Index), Offset(Offset), IsIndexSignExt(IsIndexSignExt) {}
-
- bool equalBaseIndex(const BaseIndexOffset &Other) {
- return Other.Base == Base && Other.Index == Index &&
- Other.IsIndexSignExt == IsIndexSignExt;
- }
-
- /// Parses tree in Ptr for base, index, offset addresses.
- static BaseIndexOffset match(SDValue Ptr) {
- bool IsIndexSignExt = false;
-
- // We only can pattern match BASE + INDEX + OFFSET. If Ptr is not an ADD
- // instruction, then it could be just the BASE or everything else we don't
- // know how to handle. Just use Ptr as BASE and give up.
- if (Ptr->getOpcode() != ISD::ADD)
- return BaseIndexOffset(Ptr, SDValue(), 0, IsIndexSignExt);
-
- // We know that we have at least an ADD instruction. Try to pattern match
- // the simple case of BASE + OFFSET.
- if (isa<ConstantSDNode>(Ptr->getOperand(1))) {
- int64_t Offset = cast<ConstantSDNode>(Ptr->getOperand(1))->getSExtValue();
- return BaseIndexOffset(Ptr->getOperand(0), SDValue(), Offset,
- IsIndexSignExt);
- }
-
- // Inside a loop the current BASE pointer is calculated using an ADD and a
- // MUL instruction. In this case Ptr is the actual BASE pointer.
- // (i64 add (i64 %array_ptr)
- // (i64 mul (i64 %induction_var)
- // (i64 %element_size)))
- if (Ptr->getOperand(1)->getOpcode() == ISD::MUL)
- return BaseIndexOffset(Ptr, SDValue(), 0, IsIndexSignExt);
-
- // Look at Base + Index + Offset cases.
- SDValue Base = Ptr->getOperand(0);
- SDValue IndexOffset = Ptr->getOperand(1);
-
- // Skip signextends.
- if (IndexOffset->getOpcode() == ISD::SIGN_EXTEND) {
- IndexOffset = IndexOffset->getOperand(0);
- IsIndexSignExt = true;
- }
-
- // Either the case of Base + Index (no offset) or something else.
- if (IndexOffset->getOpcode() != ISD::ADD)
- return BaseIndexOffset(Base, IndexOffset, 0, IsIndexSignExt);
-
- // Now we have the case of Base + Index + offset.
- SDValue Index = IndexOffset->getOperand(0);
- SDValue Offset = IndexOffset->getOperand(1);
-
- if (!isa<ConstantSDNode>(Offset))
- return BaseIndexOffset(Ptr, SDValue(), 0, IsIndexSignExt);
-
- // Ignore signextends.
- if (Index->getOpcode() == ISD::SIGN_EXTEND) {
- Index = Index->getOperand(0);
- IsIndexSignExt = true;
- } else IsIndexSignExt = false;
-
- int64_t Off = cast<ConstantSDNode>(Offset)->getSExtValue();
- return BaseIndexOffset(Base, Index, Off, IsIndexSignExt);
- }
-};
-} // namespace
-
// This is a helper function for visitMUL to check the profitability
// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
// MulNode is the original multiply, AddNode is (add x, c1),
@@ -11022,7 +14113,6 @@ bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode,
// Walk all the users of the constant with which we're multiplying.
for (SDNode *Use : ConstNode->uses()) {
-
if (Use == MulNode) // This use is the one we're on right now. Skip it.
continue;
@@ -11063,7 +14153,7 @@ bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode,
// multiply (CONST * A) after we also do the same transformation
// to the "t2" instruction.
if (OtherOp->getOpcode() == ISD::ADD &&
- isConstantIntBuildVectorOrConstantInt(OtherOp->getOperand(1)) &&
+ DAG.isConstantIntBuildVectorOrConstantInt(OtherOp->getOperand(1)) &&
OtherOp->getOperand(0).getNode() == MulVar)
return true;
}
@@ -11073,83 +14163,118 @@ bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode,
return false;
}
-SDValue DAGCombiner::getMergedConstantVectorStore(SelectionDAG &DAG,
- SDLoc SL,
- ArrayRef<MemOpLink> Stores,
- SmallVectorImpl<SDValue> &Chains,
- EVT Ty) const {
- SmallVector<SDValue, 8> BuildVector;
+SDValue DAGCombiner::getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
+ unsigned NumStores) {
+ SmallVector<SDValue, 8> Chains;
+ SmallPtrSet<const SDNode *, 8> Visited;
+ SDLoc StoreDL(StoreNodes[0].MemNode);
+
+ for (unsigned i = 0; i < NumStores; ++i) {
+ Visited.insert(StoreNodes[i].MemNode);
+ }
- for (unsigned I = 0, E = Ty.getVectorNumElements(); I != E; ++I) {
- StoreSDNode *St = cast<StoreSDNode>(Stores[I].MemNode);
- Chains.push_back(St->getChain());
- BuildVector.push_back(St->getValue());
+ // Don't include chains that point at another candidate store; those are
+ // reached through that store's own chain.
+ for (unsigned i = 0; i < NumStores; ++i) {
+ if (Visited.count(StoreNodes[i].MemNode->getChain().getNode()) == 0)
+ Chains.push_back(StoreNodes[i].MemNode->getChain());
}
- return DAG.getNode(ISD::BUILD_VECTOR, SL, Ty, BuildVector);
+ assert(Chains.size() > 0 && "Chain should have generated a chain");
+ return DAG.getNode(ISD::TokenFactor, StoreDL, MVT::Other, Chains);
}
bool DAGCombiner::MergeStoresOfConstantsOrVecElts(
- SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT,
- unsigned NumStores, bool IsConstantSrc, bool UseVector) {
+ SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT, unsigned NumStores,
+ bool IsConstantSrc, bool UseVector, bool UseTrunc) {
// Make sure we have something to merge.
if (NumStores < 2)
return false;
- int64_t ElementSizeBytes = MemVT.getSizeInBits() / 8;
- LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
- unsigned LatestNodeUsed = 0;
-
- for (unsigned i=0; i < NumStores; ++i) {
- // Find a chain for the new wide-store operand. Notice that some
- // of the store nodes that we found may not be selected for inclusion
- // in the wide store. The chain we use needs to be the chain of the
- // latest store node which is *used* and replaced by the wide store.
- if (StoreNodes[i].SequenceNum < StoreNodes[LatestNodeUsed].SequenceNum)
- LatestNodeUsed = i;
- }
-
- SmallVector<SDValue, 8> Chains;
-
// The latest Node in the DAG.
- LSBaseSDNode *LatestOp = StoreNodes[LatestNodeUsed].MemNode;
SDLoc DL(StoreNodes[0].MemNode);
- SDValue StoredVal;
+ int64_t ElementSizeBits = MemVT.getStoreSizeInBits();
+ unsigned SizeInBits = NumStores * ElementSizeBits;
+ unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
+
+ EVT StoreTy;
if (UseVector) {
- bool IsVec = MemVT.isVector();
- unsigned Elts = NumStores;
- if (IsVec) {
- // When merging vector stores, get the total number of elements.
- Elts *= MemVT.getVectorNumElements();
- }
+ unsigned Elts = NumStores * NumMemElts;
// Get the type for the merged vector store.
- EVT Ty = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
- assert(TLI.isTypeLegal(Ty) && "Illegal vector store");
+ StoreTy = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
+ } else
+ StoreTy = EVT::getIntegerVT(*DAG.getContext(), SizeInBits);
+ SDValue StoredVal;
+ if (UseVector) {
if (IsConstantSrc) {
- StoredVal = getMergedConstantVectorStore(DAG, DL, StoreNodes, Chains, Ty);
+ SmallVector<SDValue, 8> BuildVector;
+ for (unsigned I = 0; I != NumStores; ++I) {
+ StoreSDNode *St = cast<StoreSDNode>(StoreNodes[I].MemNode);
+ SDValue Val = St->getValue();
+ // If constant is of the wrong type, convert it now.
+ if (MemVT != Val.getValueType()) {
+ Val = peekThroughBitcasts(Val);
+ // Deal with constants of wrong size.
+ if (ElementSizeBits != Val.getValueSizeInBits()) {
+ EVT IntMemVT =
+ EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits());
+ if (isa<ConstantFPSDNode>(Val)) {
+ // Not clear how to truncate FP values.
+ return false;
+ } else if (auto *C = dyn_cast<ConstantSDNode>(Val))
+ Val = DAG.getConstant(C->getAPIntValue()
+ .zextOrTrunc(Val.getValueSizeInBits())
+ .zextOrTrunc(ElementSizeBits),
+ SDLoc(C), IntMemVT);
+ }
+ // Make sure the value is bitcast to the correctly sized type.
+ Val = DAG.getBitcast(MemVT, Val);
+ }
+ BuildVector.push_back(Val);
+ }
+ StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
+ : ISD::BUILD_VECTOR,
+ DL, StoreTy, BuildVector);
} else {
SmallVector<SDValue, 8> Ops;
for (unsigned i = 0; i < NumStores; ++i) {
StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
- SDValue Val = St->getValue();
- // All operands of BUILD_VECTOR / CONCAT_VECTOR must have the same type.
- if (Val.getValueType() != MemVT)
- return false;
+ SDValue Val = peekThroughBitcasts(St->getValue());
+ // All operands of BUILD_VECTOR / CONCAT_VECTORS must be of
+ // type MemVT. If the underlying value is not the correct
+ // type, but it is an extraction of an appropriate vector, we
+ // can recast Val to be of the correct type. This may require
+ // converting between EXTRACT_VECTOR_ELT and
+ // EXTRACT_SUBVECTOR.
+ if ((MemVT != Val.getValueType()) &&
+ (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
+ Val.getOpcode() == ISD::EXTRACT_SUBVECTOR)) {
+ EVT MemVTScalarTy = MemVT.getScalarType();
+ // We may need to add a bitcast here to get types to line up.
+ if (MemVTScalarTy != Val.getValueType().getScalarType()) {
+ Val = DAG.getBitcast(MemVT, Val);
+ } else {
+ unsigned OpC = MemVT.isVector() ? ISD::EXTRACT_SUBVECTOR
+ : ISD::EXTRACT_VECTOR_ELT;
+ SDValue Vec = Val.getOperand(0);
+ SDValue Idx = Val.getOperand(1);
+ Val = DAG.getNode(OpC, SDLoc(Val), MemVT, Vec, Idx);
+ }
+ }
Ops.push_back(Val);
- Chains.push_back(St->getChain());
}
// Build the extracted vector elements back into a vector.
- StoredVal = DAG.getNode(IsVec ? ISD::CONCAT_VECTORS : ISD::BUILD_VECTOR,
- DL, Ty, Ops); }
+ StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
+ : ISD::BUILD_VECTOR,
+ DL, StoreTy, Ops);
+ }
} else {
// We should always use a vector store when merging extracted vector
// elements, so this path implies a store of constants.
assert(IsConstantSrc && "Merged vector elements should use vector store");
- unsigned SizeInBits = NumStores * ElementSizeBytes * 8;
APInt StoreInt(SizeInBits, 0);
// Construct a single integer constant which is made of the smaller
@@ -11158,189 +14283,259 @@ bool DAGCombiner::MergeStoresOfConstantsOrVecElts(
for (unsigned i = 0; i < NumStores; ++i) {
unsigned Idx = IsLE ? (NumStores - 1 - i) : i;
StoreSDNode *St = cast<StoreSDNode>(StoreNodes[Idx].MemNode);
- Chains.push_back(St->getChain());
SDValue Val = St->getValue();
- StoreInt <<= ElementSizeBytes * 8;
+ Val = peekThroughBitcasts(Val);
+ StoreInt <<= ElementSizeBits;
if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val)) {
- StoreInt |= C->getAPIntValue().zext(SizeInBits);
+ StoreInt |= C->getAPIntValue()
+ .zextOrTrunc(ElementSizeBits)
+ .zextOrTrunc(SizeInBits);
} else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Val)) {
- StoreInt |= C->getValueAPF().bitcastToAPInt().zext(SizeInBits);
+ StoreInt |= C->getValueAPF()
+ .bitcastToAPInt()
+ .zextOrTrunc(ElementSizeBits)
+ .zextOrTrunc(SizeInBits);
+ // If fp truncation is necessary, give up for now.
+ if (MemVT.getSizeInBits() != ElementSizeBits)
+ return false;
} else {
llvm_unreachable("Invalid constant element type");
}
}
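+    // For example (with the stores in increasing address order), merging two
+    // i16 stores of 0x1234 and 0x5678 yields a single i32 store of 0x56781234
+    // on little-endian targets and 0x12345678 on big-endian targets.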
// Create the new Load and Store operations.
- EVT StoreTy = EVT::getIntegerVT(*DAG.getContext(), SizeInBits);
StoredVal = DAG.getConstant(StoreInt, DL, StoreTy);
}
- assert(!Chains.empty());
-
- SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains);
- SDValue NewStore = DAG.getStore(NewChain, DL, StoredVal,
- FirstInChain->getBasePtr(),
- FirstInChain->getPointerInfo(),
- false, false,
- FirstInChain->getAlignment());
-
- bool UseAA = CombinerAA.getNumOccurrences() > 0 ? CombinerAA
- : DAG.getSubtarget().useAA();
- if (UseAA) {
- // Replace all merged stores with the new store.
- for (unsigned i = 0; i < NumStores; ++i)
- CombineTo(StoreNodes[i].MemNode, NewStore);
- } else {
- // Replace the last store with the new store.
- CombineTo(LatestOp, NewStore);
- // Erase all other stores.
- for (unsigned i = 0; i < NumStores; ++i) {
- if (StoreNodes[i].MemNode == LatestOp)
- continue;
- StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
- // ReplaceAllUsesWith will replace all uses that existed when it was
- // called, but graph optimizations may cause new ones to appear. For
- // example, the case in pr14333 looks like
- //
- // St's chain -> St -> another store -> X
- //
- // And the only difference from St to the other store is the chain.
- // When we change it's chain to be St's chain they become identical,
- // get CSEed and the net result is that X is now a use of St.
- // Since we know that St is redundant, just iterate.
- while (!St->use_empty())
- DAG.ReplaceAllUsesWith(SDValue(St, 0), St->getChain());
- deleteAndRecombine(St);
- }
- }
-
+ LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
+ SDValue NewChain = getMergeStoreChains(StoreNodes, NumStores);
+
+ // Make sure we use a trunc store when it is necessary for legality.
+ SDValue NewStore;
+ if (!UseTrunc) {
+ NewStore = DAG.getStore(NewChain, DL, StoredVal, FirstInChain->getBasePtr(),
+ FirstInChain->getPointerInfo(),
+ FirstInChain->getAlignment());
+ } else { // Must be realized as a trunc store
+ EVT LegalizedStoredValTy =
+ TLI.getTypeToTransformTo(*DAG.getContext(), StoredVal.getValueType());
+ unsigned LegalizedStoreSize = LegalizedStoredValTy.getSizeInBits();
+ ConstantSDNode *C = cast<ConstantSDNode>(StoredVal);
+ SDValue ExtendedStoreVal =
+ DAG.getConstant(C->getAPIntValue().zextOrTrunc(LegalizedStoreSize), DL,
+ LegalizedStoredValTy);
+ NewStore = DAG.getTruncStore(
+ NewChain, DL, ExtendedStoreVal, FirstInChain->getBasePtr(),
+ FirstInChain->getPointerInfo(), StoredVal.getValueType() /*TVT*/,
+ FirstInChain->getAlignment(),
+ FirstInChain->getMemOperand()->getFlags());
+ }
+
+ // Replace all merged stores with the new store.
+ for (unsigned i = 0; i < NumStores; ++i)
+ CombineTo(StoreNodes[i].MemNode, NewStore);
+
+ AddToWorklist(NewChain.getNode());
return true;
}
-void DAGCombiner::getStoreMergeAndAliasCandidates(
- StoreSDNode* St, SmallVectorImpl<MemOpLink> &StoreNodes,
- SmallVectorImpl<LSBaseSDNode*> &AliasLoadNodes) {
+void DAGCombiner::getStoreMergeCandidates(
+ StoreSDNode *St, SmallVectorImpl<MemOpLink> &StoreNodes,
+ SDNode *&RootNode) {
// This holds the base pointer, index, and the offset in bytes from the base
// pointer.
- BaseIndexOffset BasePtr = BaseIndexOffset::match(St->getBasePtr());
+ BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
+ EVT MemVT = St->getMemoryVT();
+ SDValue Val = peekThroughBitcasts(St->getValue());
// We must have a base and an offset.
- if (!BasePtr.Base.getNode())
+ if (!BasePtr.getBase().getNode())
return;
// Do not handle stores to undef base pointers.
- if (BasePtr.Base.getOpcode() == ISD::UNDEF)
- return;
-
- // Walk up the chain and look for nodes with offsets from the same
- // base pointer. Stop when reaching an instruction with a different kind
- // or instruction which has a different base pointer.
- EVT MemVT = St->getMemoryVT();
- unsigned Seq = 0;
- StoreSDNode *Index = St;
-
-
- bool UseAA = CombinerAA.getNumOccurrences() > 0 ? CombinerAA
- : DAG.getSubtarget().useAA();
-
- if (UseAA) {
- // Look at other users of the same chain. Stores on the same chain do not
- // alias. If combiner-aa is enabled, non-aliasing stores are canonicalized
- // to be on the same chain, so don't bother looking at adjacent chains.
-
- SDValue Chain = St->getChain();
- for (auto I = Chain->use_begin(), E = Chain->use_end(); I != E; ++I) {
- if (StoreSDNode *OtherST = dyn_cast<StoreSDNode>(*I)) {
- if (I.getOperandNo() != 0)
- continue;
-
- if (OtherST->isVolatile() || OtherST->isIndexed())
- continue;
-
- if (OtherST->getMemoryVT() != MemVT)
- continue;
-
- BaseIndexOffset Ptr = BaseIndexOffset::match(OtherST->getBasePtr());
-
- if (Ptr.equalBaseIndex(BasePtr))
- StoreNodes.push_back(MemOpLink(OtherST, Ptr.Offset, Seq++));
- }
- }
-
+ if (BasePtr.getBase().isUndef())
return;
- }
-
- while (Index) {
- // If the chain has more than one use, then we can't reorder the mem ops.
- if (Index != St && !SDValue(Index, 0)->hasOneUse())
- break;
-
- // Find the base pointer and offset for this memory node.
- BaseIndexOffset Ptr = BaseIndexOffset::match(Index->getBasePtr());
-
- // Check that the base pointer is the same as the original one.
- if (!Ptr.equalBaseIndex(BasePtr))
- break;
+ bool IsConstantSrc = isa<ConstantSDNode>(Val) || isa<ConstantFPSDNode>(Val);
+ bool IsExtractVecSrc = (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
+ Val.getOpcode() == ISD::EXTRACT_SUBVECTOR);
+ bool IsLoadSrc = isa<LoadSDNode>(Val);
+ BaseIndexOffset LBasePtr;
+  // Match on the load's base pointer if relevant.
+ EVT LoadVT;
+ if (IsLoadSrc) {
+ auto *Ld = cast<LoadSDNode>(Val);
+ LBasePtr = BaseIndexOffset::match(Ld, DAG);
+ LoadVT = Ld->getMemoryVT();
+ // Load and store should be the same type.
+ if (MemVT != LoadVT)
+ return;
+ // Loads must only have one use.
+ if (!Ld->hasNUsesOfValue(1, 0))
+ return;
// The memory operands must not be volatile.
- if (Index->isVolatile() || Index->isIndexed())
- break;
+ if (Ld->isVolatile() || Ld->isIndexed())
+ return;
+ }
+ auto CandidateMatch = [&](StoreSDNode *Other, BaseIndexOffset &Ptr,
+ int64_t &Offset) -> bool {
+ if (Other->isVolatile() || Other->isIndexed())
+ return false;
+ SDValue Val = peekThroughBitcasts(Other->getValue());
+ // Allow merging constants of different types as integers.
+ bool NoTypeMatch = (MemVT.isInteger()) ? !MemVT.bitsEq(Other->getMemoryVT())
+ : Other->getMemoryVT() != MemVT;
+ if (IsLoadSrc) {
+ if (NoTypeMatch)
+ return false;
+ // The Load's Base Ptr must also match
+ if (LoadSDNode *OtherLd = dyn_cast<LoadSDNode>(Val)) {
+ auto LPtr = BaseIndexOffset::match(OtherLd, DAG);
+ if (LoadVT != OtherLd->getMemoryVT())
+ return false;
+ // Loads must only have one use.
+ if (!OtherLd->hasNUsesOfValue(1, 0))
+ return false;
+ // The memory operands must not be volatile.
+ if (OtherLd->isVolatile() || OtherLd->isIndexed())
+ return false;
+ if (!(LBasePtr.equalBaseIndex(LPtr, DAG)))
+ return false;
+ } else
+ return false;
+ }
+ if (IsConstantSrc) {
+ if (NoTypeMatch)
+ return false;
+ if (!(isa<ConstantSDNode>(Val) || isa<ConstantFPSDNode>(Val)))
+ return false;
+ }
+ if (IsExtractVecSrc) {
+ // Do not merge truncated stores here.
+ if (Other->isTruncatingStore())
+ return false;
+ if (!MemVT.bitsEq(Val.getValueType()))
+ return false;
+ if (Val.getOpcode() != ISD::EXTRACT_VECTOR_ELT &&
+ Val.getOpcode() != ISD::EXTRACT_SUBVECTOR)
+ return false;
+ }
+ Ptr = BaseIndexOffset::match(Other, DAG);
+ return (BasePtr.equalBaseIndex(Ptr, DAG, Offset));
+ };
- // No truncation.
- if (StoreSDNode *St = dyn_cast<StoreSDNode>(Index))
- if (St->isTruncatingStore())
- break;
+  // We are looking for a root node which is an ancestor to all mergeable
+  // stores. We search up through a load, to our root and then down
+  // through all children. For instance, we will find Store{1,2,3} if
+  // St is Store1, Store2, or Store3 where the root is not a load,
+  // which is always true for nonvolatile ops. TODO: Expand
+  // the search to find all valid candidates through multiple layers of loads.
+ //
+ // Root
+ // |-------|-------|
+ // Load Load Store3
+ // | |
+ // Store1 Store2
+ //
+ // FIXME: We should be able to climb and
+ // descend TokenFactors to find candidates as well.
+
+ RootNode = St->getChain().getNode();
+
+ if (LoadSDNode *Ldn = dyn_cast<LoadSDNode>(RootNode)) {
+ RootNode = Ldn->getChain().getNode();
+ for (auto I = RootNode->use_begin(), E = RootNode->use_end(); I != E; ++I)
+ if (I.getOperandNo() == 0 && isa<LoadSDNode>(*I)) // walk down chain
+ for (auto I2 = (*I)->use_begin(), E2 = (*I)->use_end(); I2 != E2; ++I2)
+ if (I2.getOperandNo() == 0)
+ if (StoreSDNode *OtherST = dyn_cast<StoreSDNode>(*I2)) {
+ BaseIndexOffset Ptr;
+ int64_t PtrDiff;
+ if (CandidateMatch(OtherST, Ptr, PtrDiff))
+ StoreNodes.push_back(MemOpLink(OtherST, PtrDiff));
+ }
+ } else
+ for (auto I = RootNode->use_begin(), E = RootNode->use_end(); I != E; ++I)
+ if (I.getOperandNo() == 0)
+ if (StoreSDNode *OtherST = dyn_cast<StoreSDNode>(*I)) {
+ BaseIndexOffset Ptr;
+ int64_t PtrDiff;
+ if (CandidateMatch(OtherST, Ptr, PtrDiff))
+ StoreNodes.push_back(MemOpLink(OtherST, PtrDiff));
+ }
+}
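
A minimal sketch of the candidate walk above, using a toy node type rather than the real SDNode API; it is simplified in that it collapses the two branches of the real code into one (always scanning one level of loads) and omits the CandidateMatch filtering.

#include <vector>

struct ToyNode {
  bool IsLoad = false, IsStore = false;
  ToyNode *Chain = nullptr;              // stands in for operand 0
  std::vector<ToyNode *> Users;
};

void collectCandidateStores(ToyNode *St, std::vector<ToyNode *> &Out) {
  ToyNode *Root = St->Chain;
  if (Root && Root->IsLoad)
    Root = Root->Chain;                  // climb through the load layer
  if (!Root)
    return;
  for (ToyNode *U : Root->Users) {
    if (U->IsStore)
      Out.push_back(U);                  // store chained directly to the root
    else if (U->IsLoad)
      for (ToyNode *U2 : U->Users)
        if (U2->IsStore)
          Out.push_back(U2);             // store below a sibling load
  }
}
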
- // The stored memory type must be the same.
- if (Index->getMemoryVT() != MemVT)
- break;
+// We need to check that merging these stores does not cause a loop in
+// the DAG. Any store candidate may depend on another candidate
+// indirectly through its operands (we already consider dependencies
+// through the chain). Check in parallel by searching up from
+// non-chain operands of candidates.
+bool DAGCombiner::checkMergeStoreCandidatesForDependencies(
+ SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
+ SDNode *RootNode) {
+ // FIXME: We should be able to truncate a full search of
+  // predecessors by doing a BFS and keeping tabs on the originating
+  // stores from which worklist nodes come, in a similar way to
+  // TokenFactor simplification.
- // We do not allow under-aligned stores in order to prevent
- // overriding stores. NOTE: this is a bad hack. Alignment SHOULD
- // be irrelevant here; what MATTERS is that we not move memory
- // operations that potentially overlap past each-other.
- if (Index->getAlignment() < MemVT.getStoreSize())
- break;
+ SmallPtrSet<const SDNode *, 32> Visited;
+ SmallVector<const SDNode *, 8> Worklist;
- // We found a potential memory operand to merge.
- StoreNodes.push_back(MemOpLink(Index, Ptr.Offset, Seq++));
-
- // Find the next memory operand in the chain. If the next operand in the
- // chain is a store then move up and continue the scan with the next
- // memory operand. If the next operand is a load save it and use alias
- // information to check if it interferes with anything.
- SDNode *NextInChain = Index->getChain().getNode();
- while (1) {
- if (StoreSDNode *STn = dyn_cast<StoreSDNode>(NextInChain)) {
- // We found a store node. Use it for the next iteration.
- Index = STn;
- break;
- } else if (LoadSDNode *Ldn = dyn_cast<LoadSDNode>(NextInChain)) {
- if (Ldn->isVolatile()) {
- Index = nullptr;
- break;
- }
+ // RootNode is a predecessor to all candidates so we need not search
+ // past it. Add RootNode (peeking through TokenFactors). Do not count
+  // these toward the size check.
- // Save the load node for later. Continue the scan.
- AliasLoadNodes.push_back(Ldn);
- NextInChain = Ldn->getChain().getNode();
- continue;
- } else {
- Index = nullptr;
- break;
- }
- }
- }
+ Worklist.push_back(RootNode);
+ while (!Worklist.empty()) {
+ auto N = Worklist.pop_back_val();
+ if (!Visited.insert(N).second)
+ continue; // Already present in Visited.
+ if (N->getOpcode() == ISD::TokenFactor) {
+ for (SDValue Op : N->ops())
+ Worklist.push_back(Op.getNode());
+ }
+ }
+
+ // Don't count pruning nodes towards max.
+ unsigned int Max = 1024 + Visited.size();
+ // Search Ops of store candidates.
+ for (unsigned i = 0; i < NumStores; ++i) {
+ SDNode *N = StoreNodes[i].MemNode;
+ // Of the 4 Store Operands:
+    // * Chain (Op 0) -> We have already considered these
+    //                   in candidate selection, so they can be
+    //                   safely ignored
+ // * Value (Op 1) -> Cycles may happen (e.g. through load chains)
+ // * Address (Op 2) -> Merged addresses may only vary by a fixed constant,
+    //                      but aren't necessarily from the same base node, so
+    //                      cycles are possible (e.g. via indexed store).
+    // * (Op 3) -> Represents the pre- or post-indexing offset (or undef for
+ // non-indexed stores). Not constant on all targets (e.g. ARM)
+ // and so can participate in a cycle.
+ for (unsigned j = 1; j < N->getNumOperands(); ++j)
+ Worklist.push_back(N->getOperand(j).getNode());
+ }
+ // Search through DAG. We can stop early if we find a store node.
+ for (unsigned i = 0; i < NumStores; ++i)
+ if (SDNode::hasPredecessorHelper(StoreNodes[i].MemNode, Visited, Worklist,
+ Max))
+ return false;
+ return true;
}
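
A standalone sketch of the same worklist idea, assuming a toy Node whose Ops are its operand edges (not the real hasPredecessorHelper, which also gives up conservatively when the budget runs out).

#include <unordered_set>
#include <vector>

struct Node {
  std::vector<Node *> Ops;
};

// True if Target is reachable from any worklist root by walking operand
// edges -- the cycle hazard the check above guards against.
bool reachesTarget(std::vector<const Node *> Worklist, const Node *Target,
                   unsigned Budget = 1024) {
  std::unordered_set<const Node *> Visited;
  while (!Worklist.empty() && Budget--) {
    const Node *N = Worklist.back();
    Worklist.pop_back();
    if (!Visited.insert(N).second)
      continue;                          // already explored
    if (N == Target)
      return true;
    for (const Node *Op : N->Ops)
      Worklist.push_back(Op);
  }
  return false;                          // not reachable within budget
}
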
-bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) {
+bool DAGCombiner::MergeConsecutiveStores(StoreSDNode *St) {
if (OptLevel == CodeGenOpt::None)
return false;
EVT MemVT = St->getMemoryVT();
- int64_t ElementSizeBytes = MemVT.getSizeInBits() / 8;
- bool NoVectors = DAG.getMachineFunction().getFunction()->hasFnAttribute(
+ int64_t ElementSizeBytes = MemVT.getStoreSize();
+ unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
+
+ if (MemVT.getSizeInBits() * 2 > MaximumLegalStoreInBits)
+ return false;
+
+ bool NoVectors = DAG.getMachineFunction().getFunction().hasFnAttribute(
Attribute::NoImplicitFloat);
// This function cannot currently deal with non-byte-sized memory sizes.
@@ -11352,7 +14547,7 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) {
// Perform an early exit check. Do not bother looking at stored values that
// are not constants, loads, or extracted vector elements.
- SDValue StoredVal = St->getValue();
+ SDValue StoredVal = peekThroughBitcasts(St->getValue());
bool IsLoadSrc = isa<LoadSDNode>(StoredVal);
bool IsConstantSrc = isa<ConstantSDNode>(StoredVal) ||
isa<ConstantFPSDNode>(StoredVal);
@@ -11362,381 +14557,495 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) {
if (!IsConstantSrc && !IsLoadSrc && !IsExtractVecSrc)
return false;
- // Don't merge vectors into wider vectors if the source data comes from loads.
- // TODO: This restriction can be lifted by using logic similar to the
- // ExtractVecSrc case.
- if (MemVT.isVector() && IsLoadSrc)
- return false;
-
- // Only look at ends of store sequences.
- SDValue Chain = SDValue(St, 0);
- if (Chain->hasOneUse() && Chain->use_begin()->getOpcode() == ISD::STORE)
- return false;
-
- // Save the LoadSDNodes that we find in the chain.
- // We need to make sure that these nodes do not interfere with
- // any of the store nodes.
- SmallVector<LSBaseSDNode*, 8> AliasLoadNodes;
-
- // Save the StoreSDNodes that we find in the chain.
SmallVector<MemOpLink, 8> StoreNodes;
-
- getStoreMergeAndAliasCandidates(St, StoreNodes, AliasLoadNodes);
+ SDNode *RootNode;
+ // Find potential store merge candidates by searching through chain sub-DAG
+ getStoreMergeCandidates(St, StoreNodes, RootNode);
// Check if there is anything to merge.
if (StoreNodes.size() < 2)
return false;
// Sort the memory operands according to their distance from the
- // base pointer. As a secondary criteria: make sure stores coming
- // later in the code come first in the list. This is important for
- // the non-UseAA case, because we're merging stores into the FINAL
- // store along a chain which potentially contains aliasing stores.
- // Thus, if there are multiple stores to the same address, the last
- // one can be considered for merging but not the others.
- std::sort(StoreNodes.begin(), StoreNodes.end(),
- [](MemOpLink LHS, MemOpLink RHS) {
- return LHS.OffsetFromBase < RHS.OffsetFromBase ||
- (LHS.OffsetFromBase == RHS.OffsetFromBase &&
- LHS.SequenceNum < RHS.SequenceNum);
+ // base pointer.
+ llvm::sort(StoreNodes, [](MemOpLink LHS, MemOpLink RHS) {
+ return LHS.OffsetFromBase < RHS.OffsetFromBase;
});
- // Scan the memory operations on the chain and find the first non-consecutive
- // store memory address.
- unsigned LastConsecutiveStore = 0;
- int64_t StartAddress = StoreNodes[0].OffsetFromBase;
- for (unsigned i = 0, e = StoreNodes.size(); i < e; ++i) {
-
+  // Store Merge attempts to merge the lowest stores. This generally
+  // works out well when successful, as the remaining stores are checked
+ // after the first collection of stores is merged. However, in the
+ // case that a non-mergeable store is found first, e.g., {p[-2],
+ // p[0], p[1], p[2], p[3]}, we would fail and miss the subsequent
+ // mergeable cases. To prevent this, we prune such stores from the
+ // front of StoreNodes here.
+
+ bool RV = false;
+ while (StoreNodes.size() > 1) {
+ unsigned StartIdx = 0;
+ while ((StartIdx + 1 < StoreNodes.size()) &&
+ StoreNodes[StartIdx].OffsetFromBase + ElementSizeBytes !=
+ StoreNodes[StartIdx + 1].OffsetFromBase)
+ ++StartIdx;
+
+ // Bail if we don't have enough candidates to merge.
+ if (StartIdx + 1 >= StoreNodes.size())
+ return RV;
+
+ if (StartIdx)
+ StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + StartIdx);
+
+ // Scan the memory operations on the chain and find the first
+ // non-consecutive store memory address.
+ unsigned NumConsecutiveStores = 1;
+ int64_t StartAddress = StoreNodes[0].OffsetFromBase;
// Check that the addresses are consecutive starting from the second
// element in the list of stores.
- if (i > 0) {
+ for (unsigned i = 1, e = StoreNodes.size(); i < e; ++i) {
int64_t CurrAddress = StoreNodes[i].OffsetFromBase;
if (CurrAddress - StartAddress != (ElementSizeBytes * i))
break;
+ NumConsecutiveStores = i + 1;
}
- // Check if this store interferes with any of the loads that we found.
- // If we find a load that alias with this store. Stop the sequence.
- if (std::any_of(AliasLoadNodes.begin(), AliasLoadNodes.end(),
- [&](LSBaseSDNode* Ldn) {
- return isAlias(Ldn, StoreNodes[i].MemNode);
- }))
- break;
+ if (NumConsecutiveStores < 2) {
+ StoreNodes.erase(StoreNodes.begin(),
+ StoreNodes.begin() + NumConsecutiveStores);
+ continue;
+ }
- // Mark this node as useful.
- LastConsecutiveStore = i;
- }
+ // The node with the lowest store address.
+ LLVMContext &Context = *DAG.getContext();
+ const DataLayout &DL = DAG.getDataLayout();
- // The node with the lowest store address.
- LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
- unsigned FirstStoreAS = FirstInChain->getAddressSpace();
- unsigned FirstStoreAlign = FirstInChain->getAlignment();
- LLVMContext &Context = *DAG.getContext();
- const DataLayout &DL = DAG.getDataLayout();
-
- // Store the constants into memory as one consecutive store.
- if (IsConstantSrc) {
- unsigned LastLegalType = 0;
- unsigned LastLegalVectorType = 0;
- bool NonZero = false;
- for (unsigned i=0; i<LastConsecutiveStore+1; ++i) {
- StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
- SDValue StoredVal = St->getValue();
-
- if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(StoredVal)) {
- NonZero |= !C->isNullValue();
- } else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(StoredVal)) {
- NonZero |= !C->getConstantFPValue()->isNullValue();
- } else {
- // Non-constant.
- break;
- }
+ // Store the constants into memory as one consecutive store.
+ if (IsConstantSrc) {
+ while (NumConsecutiveStores >= 2) {
+ LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
+ unsigned FirstStoreAS = FirstInChain->getAddressSpace();
+ unsigned FirstStoreAlign = FirstInChain->getAlignment();
+ unsigned LastLegalType = 1;
+ unsigned LastLegalVectorType = 1;
+ bool LastIntegerTrunc = false;
+ bool NonZero = false;
+ unsigned FirstZeroAfterNonZero = NumConsecutiveStores;
+ for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
+ StoreSDNode *ST = cast<StoreSDNode>(StoreNodes[i].MemNode);
+ SDValue StoredVal = ST->getValue();
+ bool IsElementZero = false;
+ if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(StoredVal))
+ IsElementZero = C->isNullValue();
+ else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(StoredVal))
+ IsElementZero = C->getConstantFPValue()->isNullValue();
+ if (IsElementZero) {
+ if (NonZero && FirstZeroAfterNonZero == NumConsecutiveStores)
+ FirstZeroAfterNonZero = i;
+ }
+ NonZero |= !IsElementZero;
+
+ // Find a legal type for the constant store.
+ unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
+ EVT StoreTy = EVT::getIntegerVT(Context, SizeInBits);
+ bool IsFast = false;
+
+ // Break early when size is too large to be legal.
+ if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
+ break;
+
+ if (TLI.isTypeLegal(StoreTy) &&
+ TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) &&
+ TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS,
+ FirstStoreAlign, &IsFast) &&
+ IsFast) {
+ LastIntegerTrunc = false;
+ LastLegalType = i + 1;
+ // Or check whether a truncstore is legal.
+ } else if (TLI.getTypeAction(Context, StoreTy) ==
+ TargetLowering::TypePromoteInteger) {
+ EVT LegalizedStoredValTy =
+ TLI.getTypeToTransformTo(Context, StoredVal.getValueType());
+ if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
+ TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy, DAG) &&
+ TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS,
+ FirstStoreAlign, &IsFast) &&
+ IsFast) {
+ LastIntegerTrunc = true;
+ LastLegalType = i + 1;
+ }
+ }
- // Find a legal type for the constant store.
- unsigned SizeInBits = (i+1) * ElementSizeBytes * 8;
- EVT StoreTy = EVT::getIntegerVT(Context, SizeInBits);
- bool IsFast;
- if (TLI.isTypeLegal(StoreTy) &&
- TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS,
- FirstStoreAlign, &IsFast) && IsFast) {
- LastLegalType = i+1;
- // Or check whether a truncstore is legal.
- } else if (TLI.getTypeAction(Context, StoreTy) ==
- TargetLowering::TypePromoteInteger) {
- EVT LegalizedStoredValueTy =
- TLI.getTypeToTransformTo(Context, StoredVal.getValueType());
- if (TLI.isTruncStoreLegal(LegalizedStoredValueTy, StoreTy) &&
- TLI.allowsMemoryAccess(Context, DL, LegalizedStoredValueTy,
- FirstStoreAS, FirstStoreAlign, &IsFast) &&
- IsFast) {
- LastLegalType = i + 1;
+ // We only use vectors if the constant is known to be zero or the
+ // target allows it and the function is not marked with the
+ // noimplicitfloat attribute.
+ if ((!NonZero ||
+ TLI.storeOfVectorConstantIsCheap(MemVT, i + 1, FirstStoreAS)) &&
+ !NoVectors) {
+ // Find a legal type for the vector store.
+ unsigned Elts = (i + 1) * NumMemElts;
+ EVT Ty = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
+ if (TLI.isTypeLegal(Ty) && TLI.isTypeLegal(MemVT) &&
+ TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG) &&
+ TLI.allowsMemoryAccess(Context, DL, Ty, FirstStoreAS,
+ FirstStoreAlign, &IsFast) &&
+ IsFast)
+ LastLegalVectorType = i + 1;
+ }
}
- }
- // We only use vectors if the constant is known to be zero or the target
- // allows it and the function is not marked with the noimplicitfloat
- // attribute.
- if ((!NonZero || TLI.storeOfVectorConstantIsCheap(MemVT, i+1,
- FirstStoreAS)) &&
- !NoVectors) {
- // Find a legal type for the vector store.
- EVT Ty = EVT::getVectorVT(Context, MemVT, i+1);
- if (TLI.isTypeLegal(Ty) &&
- TLI.allowsMemoryAccess(Context, DL, Ty, FirstStoreAS,
- FirstStoreAlign, &IsFast) && IsFast)
- LastLegalVectorType = i + 1;
- }
- }
+ bool UseVector = (LastLegalVectorType > LastLegalType) && !NoVectors;
+ unsigned NumElem = (UseVector) ? LastLegalVectorType : LastLegalType;
+
+ // Check if we found a legal integer type that creates a meaningful
+ // merge.
+ if (NumElem < 2) {
+ // We know that candidate stores are in order and of correct
+ // shape. While there is no mergeable sequence from the
+        // beginning, one may start later in the sequence. The only
+ // reason a merge of size N could have failed where another of
+ // the same size would not have, is if the alignment has
+ // improved or we've dropped a non-zero value. Drop as many
+ // candidates as we can here.
+ unsigned NumSkip = 1;
+ while (
+ (NumSkip < NumConsecutiveStores) &&
+ (NumSkip < FirstZeroAfterNonZero) &&
+ (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign))
+ NumSkip++;
+
+ StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
+ NumConsecutiveStores -= NumSkip;
+ continue;
+ }
- // Check if we found a legal integer type to store.
- if (LastLegalType == 0 && LastLegalVectorType == 0)
- return false;
+ // Check that we can merge these candidates without causing a cycle.
+ if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
+ RootNode)) {
+ StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
+ NumConsecutiveStores -= NumElem;
+ continue;
+ }
- bool UseVector = (LastLegalVectorType > LastLegalType) && !NoVectors;
- unsigned NumElem = UseVector ? LastLegalVectorType : LastLegalType;
-
- return MergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumElem,
- true, UseVector);
- }
-
- // When extracting multiple vector elements, try to store them
- // in one vector store rather than a sequence of scalar stores.
- if (IsExtractVecSrc) {
- unsigned NumStoresToMerge = 0;
- bool IsVec = MemVT.isVector();
- for (unsigned i = 0; i < LastConsecutiveStore + 1; ++i) {
- StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
- unsigned StoreValOpcode = St->getValue().getOpcode();
- // This restriction could be loosened.
- // Bail out if any stored values are not elements extracted from a vector.
- // It should be possible to handle mixed sources, but load sources need
- // more careful handling (see the block of code below that handles
- // consecutive loads).
- if (StoreValOpcode != ISD::EXTRACT_VECTOR_ELT &&
- StoreValOpcode != ISD::EXTRACT_SUBVECTOR)
- return false;
+ RV |= MergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumElem, true,
+ UseVector, LastIntegerTrunc);
- // Find a legal type for the vector store.
- unsigned Elts = i + 1;
- if (IsVec) {
- // When merging vector stores, get the total number of elements.
- Elts *= MemVT.getVectorNumElements();
+ // Remove merged stores for next iteration.
+ StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
+ NumConsecutiveStores -= NumElem;
}
- EVT Ty = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
- bool IsFast;
- if (TLI.isTypeLegal(Ty) &&
- TLI.allowsMemoryAccess(Context, DL, Ty, FirstStoreAS,
- FirstStoreAlign, &IsFast) && IsFast)
- NumStoresToMerge = i + 1;
+ continue;
}
- return MergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumStoresToMerge,
- false, true);
- }
+ // When extracting multiple vector elements, try to store them
+ // in one vector store rather than a sequence of scalar stores.
+ if (IsExtractVecSrc) {
+      // Loop over the consecutive stores while merging succeeds.
+ while (NumConsecutiveStores >= 2) {
+ LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
+ unsigned FirstStoreAS = FirstInChain->getAddressSpace();
+ unsigned FirstStoreAlign = FirstInChain->getAlignment();
+ unsigned NumStoresToMerge = 1;
+ for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
+ // Find a legal type for the vector store.
+ unsigned Elts = (i + 1) * NumMemElts;
+ EVT Ty =
+ EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
+ bool IsFast;
+
+ // Break early when size is too large to be legal.
+ if (Ty.getSizeInBits() > MaximumLegalStoreInBits)
+ break;
- // Below we handle the case of multiple consecutive stores that
- // come from multiple consecutive loads. We merge them into a single
- // wide load and a single wide store.
+ if (TLI.isTypeLegal(Ty) &&
+ TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG) &&
+ TLI.allowsMemoryAccess(Context, DL, Ty, FirstStoreAS,
+ FirstStoreAlign, &IsFast) &&
+ IsFast)
+ NumStoresToMerge = i + 1;
+ }
- // Look for load nodes which are used by the stored values.
- SmallVector<MemOpLink, 8> LoadNodes;
+        // Check if we found a legal vector type that creates a meaningful
+        // merge.
+ if (NumStoresToMerge < 2) {
+ // We know that candidate stores are in order and of correct
+ // shape. While there is no mergeable sequence from the
+          // beginning, one may start later in the sequence. The only
+ // reason a merge of size N could have failed where another of
+ // the same size would not have, is if the alignment has
+ // improved. Drop as many candidates as we can here.
+ unsigned NumSkip = 1;
+ while (
+ (NumSkip < NumConsecutiveStores) &&
+ (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign))
+ NumSkip++;
+
+ StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
+ NumConsecutiveStores -= NumSkip;
+ continue;
+ }
- // Find acceptable loads. Loads need to have the same chain (token factor),
- // must not be zext, volatile, indexed, and they must be consecutive.
- BaseIndexOffset LdBasePtr;
- for (unsigned i=0; i<LastConsecutiveStore+1; ++i) {
- StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
- LoadSDNode *Ld = dyn_cast<LoadSDNode>(St->getValue());
- if (!Ld) break;
+ // Check that we can merge these candidates without causing a cycle.
+ if (!checkMergeStoreCandidatesForDependencies(
+ StoreNodes, NumStoresToMerge, RootNode)) {
+ StoreNodes.erase(StoreNodes.begin(),
+ StoreNodes.begin() + NumStoresToMerge);
+ NumConsecutiveStores -= NumStoresToMerge;
+ continue;
+ }
- // Loads must only have one use.
- if (!Ld->hasNUsesOfValue(1, 0))
- break;
+ RV |= MergeStoresOfConstantsOrVecElts(
+ StoreNodes, MemVT, NumStoresToMerge, false, true, false);
- // The memory operands must not be volatile.
- if (Ld->isVolatile() || Ld->isIndexed())
- break;
+ StoreNodes.erase(StoreNodes.begin(),
+ StoreNodes.begin() + NumStoresToMerge);
+ NumConsecutiveStores -= NumStoresToMerge;
+ }
+ continue;
+ }
- // We do not accept ext loads.
- if (Ld->getExtensionType() != ISD::NON_EXTLOAD)
- break;
+ // Below we handle the case of multiple consecutive stores that
+ // come from multiple consecutive loads. We merge them into a single
+ // wide load and a single wide store.
- // The stored memory type must be the same.
- if (Ld->getMemoryVT() != MemVT)
- break;
+ // Look for load nodes which are used by the stored values.
+ SmallVector<MemOpLink, 8> LoadNodes;
- BaseIndexOffset LdPtr = BaseIndexOffset::match(Ld->getBasePtr());
- // If this is not the first ptr that we check.
- if (LdBasePtr.Base.getNode()) {
- // The base ptr must be the same.
- if (!LdPtr.equalBaseIndex(LdBasePtr))
- break;
- } else {
- // Check that all other base pointers are the same as this one.
- LdBasePtr = LdPtr;
- }
+ // Find acceptable loads. Loads need to have the same chain (token factor),
+ // must not be zext, volatile, indexed, and they must be consecutive.
+ BaseIndexOffset LdBasePtr;
- // We found a potential memory operand to merge.
- LoadNodes.push_back(MemOpLink(Ld, LdPtr.Offset, 0));
- }
+ for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
+ StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
+ SDValue Val = peekThroughBitcasts(St->getValue());
+ LoadSDNode *Ld = cast<LoadSDNode>(Val);
+
+ BaseIndexOffset LdPtr = BaseIndexOffset::match(Ld, DAG);
+ // If this is not the first ptr that we check.
+ int64_t LdOffset = 0;
+ if (LdBasePtr.getBase().getNode()) {
+ // The base ptr must be the same.
+ if (!LdBasePtr.equalBaseIndex(LdPtr, DAG, LdOffset))
+ break;
+ } else {
+ // Check that all other base pointers are the same as this one.
+ LdBasePtr = LdPtr;
+ }
- if (LoadNodes.size() < 2)
- return false;
+ // We found a potential memory operand to merge.
+ LoadNodes.push_back(MemOpLink(Ld, LdOffset));
+ }
- // If we have load/store pair instructions and we only have two values,
- // don't bother.
- unsigned RequiredAlignment;
- if (LoadNodes.size() == 2 && TLI.hasPairedLoad(MemVT, RequiredAlignment) &&
- St->getAlignment() >= RequiredAlignment)
- return false;
+ while (NumConsecutiveStores >= 2 && LoadNodes.size() >= 2) {
+ // If we have load/store pair instructions and we only have two values,
+ // don't bother merging.
+ unsigned RequiredAlignment;
+ if (LoadNodes.size() == 2 &&
+ TLI.hasPairedLoad(MemVT, RequiredAlignment) &&
+ StoreNodes[0].MemNode->getAlignment() >= RequiredAlignment) {
+ StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 2);
+ LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + 2);
+ break;
+ }
+ LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
+ unsigned FirstStoreAS = FirstInChain->getAddressSpace();
+ unsigned FirstStoreAlign = FirstInChain->getAlignment();
+ LoadSDNode *FirstLoad = cast<LoadSDNode>(LoadNodes[0].MemNode);
+ unsigned FirstLoadAS = FirstLoad->getAddressSpace();
+ unsigned FirstLoadAlign = FirstLoad->getAlignment();
+
+ // Scan the memory operations on the chain and find the first
+ // non-consecutive load memory address. These variables hold the index in
+ // the store node array.
+
+ unsigned LastConsecutiveLoad = 1;
+
+ // This variable refers to the size and not index in the array.
+ unsigned LastLegalVectorType = 1;
+ unsigned LastLegalIntegerType = 1;
+ bool isDereferenceable = true;
+ bool DoIntegerTruncate = false;
+ StartAddress = LoadNodes[0].OffsetFromBase;
+ SDValue FirstChain = FirstLoad->getChain();
+ for (unsigned i = 1; i < LoadNodes.size(); ++i) {
+ // All loads must share the same chain.
+ if (LoadNodes[i].MemNode->getChain() != FirstChain)
+ break;
- LoadSDNode *FirstLoad = cast<LoadSDNode>(LoadNodes[0].MemNode);
- unsigned FirstLoadAS = FirstLoad->getAddressSpace();
- unsigned FirstLoadAlign = FirstLoad->getAlignment();
-
- // Scan the memory operations on the chain and find the first non-consecutive
- // load memory address. These variables hold the index in the store node
- // array.
- unsigned LastConsecutiveLoad = 0;
- // This variable refers to the size and not index in the array.
- unsigned LastLegalVectorType = 0;
- unsigned LastLegalIntegerType = 0;
- StartAddress = LoadNodes[0].OffsetFromBase;
- SDValue FirstChain = FirstLoad->getChain();
- for (unsigned i = 1; i < LoadNodes.size(); ++i) {
- // All loads must share the same chain.
- if (LoadNodes[i].MemNode->getChain() != FirstChain)
- break;
+ int64_t CurrAddress = LoadNodes[i].OffsetFromBase;
+ if (CurrAddress - StartAddress != (ElementSizeBytes * i))
+ break;
+ LastConsecutiveLoad = i;
- int64_t CurrAddress = LoadNodes[i].OffsetFromBase;
- if (CurrAddress - StartAddress != (ElementSizeBytes * i))
- break;
- LastConsecutiveLoad = i;
- // Find a legal type for the vector store.
- EVT StoreTy = EVT::getVectorVT(Context, MemVT, i+1);
- bool IsFastSt, IsFastLd;
- if (TLI.isTypeLegal(StoreTy) &&
- TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS,
- FirstStoreAlign, &IsFastSt) && IsFastSt &&
- TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstLoadAS,
- FirstLoadAlign, &IsFastLd) && IsFastLd) {
- LastLegalVectorType = i + 1;
- }
-
- // Find a legal type for the integer store.
- unsigned SizeInBits = (i+1) * ElementSizeBytes * 8;
- StoreTy = EVT::getIntegerVT(Context, SizeInBits);
- if (TLI.isTypeLegal(StoreTy) &&
- TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS,
- FirstStoreAlign, &IsFastSt) && IsFastSt &&
- TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstLoadAS,
- FirstLoadAlign, &IsFastLd) && IsFastLd)
- LastLegalIntegerType = i + 1;
- // Or check whether a truncstore and extload is legal.
- else if (TLI.getTypeAction(Context, StoreTy) ==
- TargetLowering::TypePromoteInteger) {
- EVT LegalizedStoredValueTy =
- TLI.getTypeToTransformTo(Context, StoreTy);
- if (TLI.isTruncStoreLegal(LegalizedStoredValueTy, StoreTy) &&
- TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValueTy, StoreTy) &&
- TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValueTy, StoreTy) &&
- TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValueTy, StoreTy) &&
- TLI.allowsMemoryAccess(Context, DL, LegalizedStoredValueTy,
- FirstStoreAS, FirstStoreAlign, &IsFastSt) &&
- IsFastSt &&
- TLI.allowsMemoryAccess(Context, DL, LegalizedStoredValueTy,
- FirstLoadAS, FirstLoadAlign, &IsFastLd) &&
- IsFastLd)
- LastLegalIntegerType = i+1;
- }
- }
-
- // Only use vector types if the vector type is larger than the integer type.
- // If they are the same, use integers.
- bool UseVectorTy = LastLegalVectorType > LastLegalIntegerType && !NoVectors;
- unsigned LastLegalType = std::max(LastLegalVectorType, LastLegalIntegerType);
-
- // We add +1 here because the LastXXX variables refer to location while
- // the NumElem refers to array/index size.
- unsigned NumElem = std::min(LastConsecutiveStore, LastConsecutiveLoad) + 1;
- NumElem = std::min(LastLegalType, NumElem);
-
- if (NumElem < 2)
- return false;
+ if (isDereferenceable && !LoadNodes[i].MemNode->isDereferenceable())
+ isDereferenceable = false;
- // Collect the chains from all merged stores.
- SmallVector<SDValue, 8> MergeStoreChains;
- MergeStoreChains.push_back(StoreNodes[0].MemNode->getChain());
+ // Find a legal type for the vector store.
+ unsigned Elts = (i + 1) * NumMemElts;
+ EVT StoreTy = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
- // The latest Node in the DAG.
- unsigned LatestNodeUsed = 0;
- for (unsigned i=1; i<NumElem; ++i) {
- // Find a chain for the new wide-store operand. Notice that some
- // of the store nodes that we found may not be selected for inclusion
- // in the wide store. The chain we use needs to be the chain of the
- // latest store node which is *used* and replaced by the wide store.
- if (StoreNodes[i].SequenceNum < StoreNodes[LatestNodeUsed].SequenceNum)
- LatestNodeUsed = i;
+ // Break early when size is too large to be legal.
+ if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
+ break;
- MergeStoreChains.push_back(StoreNodes[i].MemNode->getChain());
- }
+ bool IsFastSt, IsFastLd;
+ if (TLI.isTypeLegal(StoreTy) &&
+ TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) &&
+ TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS,
+ FirstStoreAlign, &IsFastSt) &&
+ IsFastSt &&
+ TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstLoadAS,
+ FirstLoadAlign, &IsFastLd) &&
+ IsFastLd) {
+ LastLegalVectorType = i + 1;
+ }
- LSBaseSDNode *LatestOp = StoreNodes[LatestNodeUsed].MemNode;
+ // Find a legal type for the integer store.
+ unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
+ StoreTy = EVT::getIntegerVT(Context, SizeInBits);
+ if (TLI.isTypeLegal(StoreTy) &&
+ TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) &&
+ TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS,
+ FirstStoreAlign, &IsFastSt) &&
+ IsFastSt &&
+ TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstLoadAS,
+ FirstLoadAlign, &IsFastLd) &&
+ IsFastLd) {
+ LastLegalIntegerType = i + 1;
+ DoIntegerTruncate = false;
+ // Or check whether a truncstore and extload is legal.
+ } else if (TLI.getTypeAction(Context, StoreTy) ==
+ TargetLowering::TypePromoteInteger) {
+ EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(Context, StoreTy);
+ if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
+ TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy, DAG) &&
+ TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValTy,
+ StoreTy) &&
+ TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValTy,
+ StoreTy) &&
+ TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValTy, StoreTy) &&
+ TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS,
+ FirstStoreAlign, &IsFastSt) &&
+ IsFastSt &&
+ TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstLoadAS,
+ FirstLoadAlign, &IsFastLd) &&
+ IsFastLd) {
+ LastLegalIntegerType = i + 1;
+ DoIntegerTruncate = true;
+ }
+ }
+ }
- // Find if it is better to use vectors or integers to load and store
- // to memory.
- EVT JointMemOpVT;
- if (UseVectorTy) {
- JointMemOpVT = EVT::getVectorVT(Context, MemVT, NumElem);
- } else {
- unsigned SizeInBits = NumElem * ElementSizeBytes * 8;
- JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits);
- }
+ // Only use vector types if the vector type is larger than the integer
+ // type. If they are the same, use integers.
+ bool UseVectorTy =
+ LastLegalVectorType > LastLegalIntegerType && !NoVectors;
+ unsigned LastLegalType =
+ std::max(LastLegalVectorType, LastLegalIntegerType);
+
+ // We add +1 here because the LastXXX variables refer to location while
+ // the NumElem refers to array/index size.
+ unsigned NumElem =
+ std::min(NumConsecutiveStores, LastConsecutiveLoad + 1);
+ NumElem = std::min(LastLegalType, NumElem);
+
+ if (NumElem < 2) {
+ // We know that candidate stores are in order and of correct
+ // shape. While there is no mergeable sequence from the
+          // beginning, one may start later in the sequence. The only
+ // reason a merge of size N could have failed where another of
+          // the same size would not have is if the alignment of either
+          // the load or store has improved. Drop as many candidates as we
+ // can here.
+ unsigned NumSkip = 1;
+ while ((NumSkip < LoadNodes.size()) &&
+ (LoadNodes[NumSkip].MemNode->getAlignment() <= FirstLoadAlign) &&
+ (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign))
+ NumSkip++;
+ StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
+ LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumSkip);
+ NumConsecutiveStores -= NumSkip;
+ continue;
+ }
- SDLoc LoadDL(LoadNodes[0].MemNode);
- SDLoc StoreDL(StoreNodes[0].MemNode);
+ // Check that we can merge these candidates without causing a cycle.
+ if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
+ RootNode)) {
+ StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
+ LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
+ NumConsecutiveStores -= NumElem;
+ continue;
+ }
- // The merged loads are required to have the same incoming chain, so
- // using the first's chain is acceptable.
- SDValue NewLoad = DAG.getLoad(
- JointMemOpVT, LoadDL, FirstLoad->getChain(), FirstLoad->getBasePtr(),
- FirstLoad->getPointerInfo(), false, false, false, FirstLoadAlign);
+ // Find if it is better to use vectors or integers to load and store
+ // to memory.
+ EVT JointMemOpVT;
+ if (UseVectorTy) {
+ // Find a legal type for the vector store.
+ unsigned Elts = NumElem * NumMemElts;
+ JointMemOpVT = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
+ } else {
+ unsigned SizeInBits = NumElem * ElementSizeBytes * 8;
+ JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits);
+ }
- SDValue NewStoreChain =
- DAG.getNode(ISD::TokenFactor, StoreDL, MVT::Other, MergeStoreChains);
+ SDLoc LoadDL(LoadNodes[0].MemNode);
+ SDLoc StoreDL(StoreNodes[0].MemNode);
+
+ // The merged loads are required to have the same incoming chain, so
+ // using the first's chain is acceptable.
+
+ SDValue NewStoreChain = getMergeStoreChains(StoreNodes, NumElem);
+ AddToWorklist(NewStoreChain.getNode());
+
+ MachineMemOperand::Flags MMOFlags =
+ isDereferenceable ? MachineMemOperand::MODereferenceable
+ : MachineMemOperand::MONone;
+
+ SDValue NewLoad, NewStore;
+ if (UseVectorTy || !DoIntegerTruncate) {
+ NewLoad =
+ DAG.getLoad(JointMemOpVT, LoadDL, FirstLoad->getChain(),
+ FirstLoad->getBasePtr(), FirstLoad->getPointerInfo(),
+ FirstLoadAlign, MMOFlags);
+ NewStore = DAG.getStore(
+ NewStoreChain, StoreDL, NewLoad, FirstInChain->getBasePtr(),
+ FirstInChain->getPointerInfo(), FirstStoreAlign);
+ } else { // This must be the truncstore/extload case
+ EVT ExtendedTy =
+ TLI.getTypeToTransformTo(*DAG.getContext(), JointMemOpVT);
+ NewLoad = DAG.getExtLoad(ISD::EXTLOAD, LoadDL, ExtendedTy,
+ FirstLoad->getChain(), FirstLoad->getBasePtr(),
+ FirstLoad->getPointerInfo(), JointMemOpVT,
+ FirstLoadAlign, MMOFlags);
+ NewStore = DAG.getTruncStore(NewStoreChain, StoreDL, NewLoad,
+ FirstInChain->getBasePtr(),
+ FirstInChain->getPointerInfo(),
+ JointMemOpVT, FirstInChain->getAlignment(),
+ FirstInChain->getMemOperand()->getFlags());
+ }
- SDValue NewStore = DAG.getStore(
- NewStoreChain, StoreDL, NewLoad, FirstInChain->getBasePtr(),
- FirstInChain->getPointerInfo(), false, false, FirstStoreAlign);
+ // Transfer chain users from old loads to the new load.
+ for (unsigned i = 0; i < NumElem; ++i) {
+ LoadSDNode *Ld = cast<LoadSDNode>(LoadNodes[i].MemNode);
+ DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1),
+ SDValue(NewLoad.getNode(), 1));
+ }
- // Transfer chain users from old loads to the new load.
- for (unsigned i = 0; i < NumElem; ++i) {
- LoadSDNode *Ld = cast<LoadSDNode>(LoadNodes[i].MemNode);
- DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1),
- SDValue(NewLoad.getNode(), 1));
- }
-
- bool UseAA = CombinerAA.getNumOccurrences() > 0 ? CombinerAA
- : DAG.getSubtarget().useAA();
- if (UseAA) {
- // Replace the all stores with the new store.
- for (unsigned i = 0; i < NumElem; ++i)
- CombineTo(StoreNodes[i].MemNode, NewStore);
- } else {
- // Replace the last store with the new store.
- CombineTo(LatestOp, NewStore);
- // Erase all other stores.
- for (unsigned i = 0; i < NumElem; ++i) {
- // Remove all Store nodes.
- if (StoreNodes[i].MemNode == LatestOp)
- continue;
- StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
- DAG.ReplaceAllUsesOfValueWith(SDValue(St, 0), St->getChain());
- deleteAndRecombine(St);
+      // Replace all merged stores with the new store. Recursively remove
+      // the corresponding value if it's no longer used.
+ for (unsigned i = 0; i < NumElem; ++i) {
+ SDValue Val = StoreNodes[i].MemNode->getOperand(1);
+ CombineTo(StoreNodes[i].MemNode, NewStore);
+ if (Val.getNode()->use_empty())
+ recursivelyDeleteUnusedNodes(Val.getNode());
+ }
+
+ RV = true;
+ StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
+ LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
+ NumConsecutiveStores -= NumElem;
}
}
-
- return true;
+ return RV;
}
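
A hypothetical illustration of the front-pruning described near the top of this function: p[-2] is not adjacent to p[0], so it is dropped from the front of StoreNodes and the run {p[0]..p[3]} can still merge.

void scatterThenRun(int *p) {
  p[-2] = 0;             // non-consecutive candidate: pruned
  p[0] = 1; p[1] = 2;    // consecutive run: merged where legal and fast
  p[2] = 3; p[3] = 4;
}
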
SDValue DAGCombiner::replaceStoreChain(StoreSDNode *ST, SDValue BetterChain) {
@@ -11824,21 +15133,17 @@ SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) {
std::swap(Lo, Hi);
unsigned Alignment = ST->getAlignment();
- bool isVolatile = ST->isVolatile();
- bool isNonTemporal = ST->isNonTemporal();
+ MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
AAMDNodes AAInfo = ST->getAAInfo();
- SDValue St0 = DAG.getStore(Chain, DL, Lo,
- Ptr, ST->getPointerInfo(),
- isVolatile, isNonTemporal,
- ST->getAlignment(), AAInfo);
+ SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
+ ST->getAlignment(), MMOFlags, AAInfo);
Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr,
DAG.getConstant(4, DL, Ptr.getValueType()));
Alignment = MinAlign(Alignment, 4U);
- SDValue St1 = DAG.getStore(Chain, DL, Hi,
- Ptr, ST->getPointerInfo().getWithOffset(4),
- isVolatile, isNonTemporal,
- Alignment, AAInfo);
+ SDValue St1 = DAG.getStore(Chain, DL, Hi, Ptr,
+ ST->getPointerInfo().getWithOffset(4),
+ Alignment, MMOFlags, AAInfo);
return DAG.getNode(ISD::TokenFactor, DL, MVT::Other,
St0, St1);
}
@@ -11857,34 +15162,42 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) {
// resultant store does not need a higher alignment than the original.
if (Value.getOpcode() == ISD::BITCAST && !ST->isTruncatingStore() &&
ST->isUnindexed()) {
- unsigned OrigAlign = ST->getAlignment();
EVT SVT = Value.getOperand(0).getValueType();
- unsigned Align = DAG.getDataLayout().getABITypeAlignment(
- SVT.getTypeForEVT(*DAG.getContext()));
- if (Align <= OrigAlign &&
- ((!LegalOperations && !ST->isVolatile()) ||
- TLI.isOperationLegalOrCustom(ISD::STORE, SVT)))
- return DAG.getStore(Chain, SDLoc(N), Value.getOperand(0),
- Ptr, ST->getPointerInfo(), ST->isVolatile(),
- ST->isNonTemporal(), OrigAlign,
- ST->getAAInfo());
+ // If the store is volatile, we only want to change the store type if the
+ // resulting store is legal. Otherwise we might increase the number of
+ // memory accesses. We don't care if the original type was legal or not
+ // as we assume software couldn't rely on the number of accesses of an
+ // illegal type.
+ if (((!LegalOperations && !ST->isVolatile()) ||
+ TLI.isOperationLegal(ISD::STORE, SVT)) &&
+ TLI.isStoreBitCastBeneficial(Value.getValueType(), SVT)) {
+ unsigned OrigAlign = ST->getAlignment();
+ bool Fast = false;
+ if (TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), SVT,
+ ST->getAddressSpace(), OrigAlign, &Fast) &&
+ Fast) {
+ return DAG.getStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
+ ST->getPointerInfo(), OrigAlign,
+ ST->getMemOperand()->getFlags(), ST->getAAInfo());
+ }
+ }
}
// Turn 'store undef, Ptr' -> nothing.
- if (Value.getOpcode() == ISD::UNDEF && ST->isUnindexed())
+ if (Value.isUndef() && ST->isUnindexed())
return Chain;
// Try to infer better alignment information than the store already has.
if (OptLevel != CodeGenOpt::None && ST->isUnindexed()) {
if (unsigned Align = DAG.InferPtrAlignment(Ptr)) {
- if (Align > ST->getAlignment()) {
+ if (Align > ST->getAlignment() && ST->getSrcValueOffset() % Align == 0) {
SDValue NewStore =
- DAG.getTruncStore(Chain, SDLoc(N), Value,
- Ptr, ST->getPointerInfo(), ST->getMemoryVT(),
- ST->isVolatile(), ST->isNonTemporal(), Align,
- ST->getAAInfo());
- if (NewStore.getNode() != N)
- return CombineTo(ST, NewStore, true);
+ DAG.getTruncStore(Chain, SDLoc(N), Value, Ptr, ST->getPointerInfo(),
+ ST->getMemoryVT(), Align,
+ ST->getMemOperand()->getFlags(), ST->getAAInfo());
+ // NewStore will always be N as we are only refining the alignment
+ assert(NewStore.getNode() == N);
+ (void)NewStore;
}
}
}
@@ -11894,19 +15207,7 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) {
if (SDValue NewST = TransformFPLoadStorePair(N))
return NewST;
- bool UseAA = CombinerAA.getNumOccurrences() > 0 ? CombinerAA
- : DAG.getSubtarget().useAA();
-#ifndef NDEBUG
- if (CombinerAAOnlyFunc.getNumOccurrences() &&
- CombinerAAOnlyFunc != DAG.getMachineFunction().getName())
- UseAA = false;
-#endif
- if (UseAA && ST->isUnindexed()) {
- // FIXME: We should do this even without AA enabled. AA will just allow
- // FindBetterChain to work in more situations. The problem with this is that
- // any combine that expects memory operations to be on consecutive chains
- // first needs to be updated to look for users of the same chain.
-
+ if (ST->isUnindexed()) {
// Walk up chain skipping non-aliasing memory nodes, on this store and any
// adjacent stores.
if (findBetterNeighborChains(ST)) {
@@ -11914,23 +15215,20 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) {
// manipulation. Return the original node to not do anything else.
return SDValue(ST, 0);
}
+ Chain = ST->getChain();
}
- // Try transforming N to an indexed store.
- if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
- return SDValue(N, 0);
-
// FIXME: is there such a thing as a truncating indexed store?
if (ST->isTruncatingStore() && ST->isUnindexed() &&
- Value.getValueType().isInteger()) {
+ Value.getValueType().isInteger() &&
+ (!isa<ConstantSDNode>(Value) ||
+ !cast<ConstantSDNode>(Value)->isOpaque())) {
// See if we can simplify the input to this truncstore with knowledge that
// only the low bits are being used. For example:
// "truncstore (or (shl x, 8), y), i8" -> "truncstore y, i8"
- SDValue Shorter =
- GetDemandedBits(Value,
- APInt::getLowBitsSet(
- Value.getValueType().getScalarType().getSizeInBits(),
- ST->getMemoryVT().getScalarType().getSizeInBits()));
+ SDValue Shorter = DAG.GetDemandedBits(
+ Value, APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),
+ ST->getMemoryVT().getScalarSizeInBits()));
AddToWorklist(Value.getNode());
if (Shorter.getNode())
return DAG.getTruncStore(Chain, SDLoc(N), Shorter,
@@ -11938,11 +15236,18 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) {
// Otherwise, see if we can simplify the operation with
// SimplifyDemandedBits, which only works if the value has a single use.
- if (SimplifyDemandedBits(Value,
- APInt::getLowBitsSet(
- Value.getValueType().getScalarType().getSizeInBits(),
- ST->getMemoryVT().getScalarType().getSizeInBits())))
+ if (SimplifyDemandedBits(
+ Value,
+ APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),
+ ST->getMemoryVT().getScalarSizeInBits()))) {
+      // Re-visit the store if anything changed and the store hasn't been merged
+      // with another node (N is deleted). SimplifyDemandedBits will add Value's
+ // node back to the worklist if necessary, but we also need to re-visit
+ // the Store node itself.
+ if (N->getOpcode() != ISD::DELETED_NODE)
+ AddToWorklist(N);
return SDValue(N, 0);
+ }
}
// If this is a load followed by a store to the same location, then the store
@@ -11958,14 +15263,28 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) {
}
}
- // If this is a store followed by a store with the same value to the same
- // location, then the store is dead/noop.
if (StoreSDNode *ST1 = dyn_cast<StoreSDNode>(Chain)) {
- if (ST1->getBasePtr() == Ptr && ST->getMemoryVT() == ST1->getMemoryVT() &&
- ST1->getValue() == Value && ST->isUnindexed() && !ST->isVolatile() &&
- ST1->isUnindexed() && !ST1->isVolatile()) {
- // The store is dead, remove it.
- return Chain;
+ if (ST->isUnindexed() && !ST->isVolatile() && ST1->isUnindexed() &&
+ !ST1->isVolatile() && ST1->getBasePtr() == Ptr &&
+ ST->getMemoryVT() == ST1->getMemoryVT()) {
+ // If this is a store followed by a store with the same value to the same
+ // location, then the store is dead/noop.
+ if (ST1->getValue() == Value) {
+ // The store is dead, remove it.
+ return Chain;
+ }
+
+      // If this store's preceding store writes to the same location
+      // and no other node is chained to that store, we can effectively
+      // drop that store. Do not remove stores to undef as they may be used as
+ // data sinks.
+ if (OptLevel != CodeGenOpt::None && ST1->hasOneUse() &&
+ !ST1->getBasePtr().isUndef()) {
+        // ST1 is fully overwritten and can be elided. Combine with its chain
+ // value.
+ CombineTo(ST1, ST1->getChain());
+ return SDValue();
+ }
}
}
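
A hypothetical example of the elided-store case above: the first store is fully overwritten by a later store to the same location, so with a single chain use and an OptLevel above None it is combined away.

void overwrite(int *p) {
  *p = 1;                // ST1: fully overwritten, elided
  *p = 2;                // ST: survives
}
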
@@ -11979,57 +15298,237 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) {
Ptr, ST->getMemoryVT(), ST->getMemOperand());
}
- // Only perform this optimization before the types are legal, because we
- // don't want to perform this optimization on every DAGCombine invocation.
- if (!LegalTypes) {
- bool EverChanged = false;
-
- do {
+ // Always perform this optimization before types are legal. If the target
+ // prefers, also try this after legalization to catch stores that were created
+ // by intrinsics or other nodes.
+ if (!LegalTypes || (TLI.mergeStoresAfterLegalization())) {
+ while (true) {
// There can be multiple store sequences on the same chain.
// Keep trying to merge store sequences until we are unable to do so
// or until we merge the last store on the chain.
bool Changed = MergeConsecutiveStores(ST);
- EverChanged |= Changed;
if (!Changed) break;
- } while (ST->getOpcode() != ISD::DELETED_NODE);
-
- if (EverChanged)
- return SDValue(N, 0);
+      // Return N as the merge only uses CombineTo, and no worklist
+      // cleanup is necessary.
+ if (N->getOpcode() == ISD::DELETED_NODE || !isa<StoreSDNode>(N))
+ return SDValue(N, 0);
+ }
}
+ // Try transforming N to an indexed store.
+ if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
+ return SDValue(N, 0);
+
// Turn 'store float 1.0, Ptr' -> 'store int 0x12345678, Ptr'
//
// Make sure to do this only after attempting to merge stores in order to
// avoid changing the types of some subset of stores due to visit order,
// preventing their merging.
- if (isa<ConstantFPSDNode>(Value)) {
+ if (isa<ConstantFPSDNode>(ST->getValue())) {
if (SDValue NewSt = replaceStoreOfFPConstant(ST))
return NewSt;
}
+ if (SDValue NewSt = splitMergedValStore(ST))
+ return NewSt;
+
return ReduceLoadOpStoreWidth(N);
}
+/// In the store instruction sequence below, the F and I values
+/// are bundled together as an i64 value before being stored into memory.
+/// Sometimes it is more efficient to generate separate stores for F and I,
+/// which can remove the bitwise instructions or sink them to colder places.
+///
+/// (store (or (zext (bitcast F to i32) to i64),
+/// (shl (zext I to i64), 32)), addr) -->
+/// (store F, addr) and (store I, addr+4)
+///
+/// Similarly, splitting other merged stores can also be beneficial, for example:
+/// For pair of {i32, i32}, i64 store --> two i32 stores.
+/// For pair of {i32, i16}, i64 store --> two i32 stores.
+/// For pair of {i16, i16}, i32 store --> two i16 stores.
+/// For pair of {i16, i8}, i32 store --> two i16 stores.
+/// For pair of {i8, i8}, i16 store --> two i8 stores.
+///
+/// We allow each target to determine specifically which kind of splitting is
+/// supported.
+///
+/// These store patterns are commonly seen in the simple code snippet below
+/// when only std::make_pair(...) is SROA-transformed before being inlined into hoo.
+/// void goo(const std::pair<int, float> &);
+/// hoo() {
+/// ...
+/// goo(std::make_pair(tmp, ftmp));
+/// ...
+/// }
+///
+SDValue DAGCombiner::splitMergedValStore(StoreSDNode *ST) {
+ if (OptLevel == CodeGenOpt::None)
+ return SDValue();
+
+ SDValue Val = ST->getValue();
+ SDLoc DL(ST);
+
+ // Match OR operand.
+ if (!Val.getValueType().isScalarInteger() || Val.getOpcode() != ISD::OR)
+ return SDValue();
+
+ // Match SHL operand and get Lower and Higher parts of Val.
+ SDValue Op1 = Val.getOperand(0);
+ SDValue Op2 = Val.getOperand(1);
+ SDValue Lo, Hi;
+ if (Op1.getOpcode() != ISD::SHL) {
+ std::swap(Op1, Op2);
+ if (Op1.getOpcode() != ISD::SHL)
+ return SDValue();
+ }
+ Lo = Op2;
+ Hi = Op1.getOperand(0);
+ if (!Op1.hasOneUse())
+ return SDValue();
+
+ // Match shift amount to HalfValBitSize.
+ unsigned HalfValBitSize = Val.getValueSizeInBits() / 2;
+ ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(Op1.getOperand(1));
+ if (!ShAmt || ShAmt->getAPIntValue() != HalfValBitSize)
+ return SDValue();
+
+  // Lo and Hi are zero-extended from an int whose size is less than or
+  // equal to 32 bits, to i64.
+ if (Lo.getOpcode() != ISD::ZERO_EXTEND || !Lo.hasOneUse() ||
+ !Lo.getOperand(0).getValueType().isScalarInteger() ||
+ Lo.getOperand(0).getValueSizeInBits() > HalfValBitSize ||
+ Hi.getOpcode() != ISD::ZERO_EXTEND || !Hi.hasOneUse() ||
+ !Hi.getOperand(0).getValueType().isScalarInteger() ||
+ Hi.getOperand(0).getValueSizeInBits() > HalfValBitSize)
+ return SDValue();
+
+ // Use the EVT of low and high parts before bitcast as the input
+ // of target query.
+ EVT LowTy = (Lo.getOperand(0).getOpcode() == ISD::BITCAST)
+ ? Lo.getOperand(0).getValueType()
+ : Lo.getValueType();
+ EVT HighTy = (Hi.getOperand(0).getOpcode() == ISD::BITCAST)
+ ? Hi.getOperand(0).getValueType()
+ : Hi.getValueType();
+ if (!TLI.isMultiStoresCheaperThanBitsMerge(LowTy, HighTy))
+ return SDValue();
+
+ // Start to split store.
+ unsigned Alignment = ST->getAlignment();
+ MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
+ AAMDNodes AAInfo = ST->getAAInfo();
+
+ // Change the sizes of Lo and Hi's value types to HalfValBitSize.
+ EVT VT = EVT::getIntegerVT(*DAG.getContext(), HalfValBitSize);
+ Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Lo.getOperand(0));
+ Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Hi.getOperand(0));
+
+ SDValue Chain = ST->getChain();
+ SDValue Ptr = ST->getBasePtr();
+ // Lower value store.
+ SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
+ ST->getAlignment(), MMOFlags, AAInfo);
+ Ptr =
+ DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr,
+ DAG.getConstant(HalfValBitSize / 8, DL, Ptr.getValueType()));
+ // Higher value store.
+ SDValue St1 =
+ DAG.getStore(St0, DL, Hi, Ptr,
+ ST->getPointerInfo().getWithOffset(HalfValBitSize / 8),
+ Alignment / 2, MMOFlags, AAInfo);
+ return St1;
+}
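
A hypothetical source-level shape of the matched pattern (names invented): two 32-bit values bit-merged into one i64 store, which splitMergedValStore undoes when the target reports two narrow stores as cheaper than the bit-merge.

#include <cstdint>
#include <cstring>

void storePair(uint64_t *Addr, float F, int32_t I) {
  uint32_t FBits;
  std::memcpy(&FBits, &F, sizeof(FBits));              // bitcast F to i32
  uint64_t Merged =
      (uint64_t)FBits | ((uint64_t)(uint32_t)I << 32); // or(zext F, shl(zext I, 32))
  *Addr = Merged;                                      // single i64 store
}
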
+
+/// Convert a disguised subvector insertion into a shuffle:
+/// insert_vector_elt V, (bitcast X from vector type), IdxC -->
+/// bitcast(shuffle (bitcast V), (extended X), Mask)
+/// Note: We do not use an insert_subvector node because that requires a legal
+/// subvector type.
+SDValue DAGCombiner::combineInsertEltToShuffle(SDNode *N, unsigned InsIndex) {
+ SDValue InsertVal = N->getOperand(1);
+ if (InsertVal.getOpcode() != ISD::BITCAST || !InsertVal.hasOneUse() ||
+ !InsertVal.getOperand(0).getValueType().isVector())
+ return SDValue();
+
+ SDValue SubVec = InsertVal.getOperand(0);
+ SDValue DestVec = N->getOperand(0);
+ EVT SubVecVT = SubVec.getValueType();
+ EVT VT = DestVec.getValueType();
+ unsigned NumSrcElts = SubVecVT.getVectorNumElements();
+ unsigned ExtendRatio = VT.getSizeInBits() / SubVecVT.getSizeInBits();
+ unsigned NumMaskVals = ExtendRatio * NumSrcElts;
+
+ // Step 1: Create a shuffle mask that implements this insert operation. The
+ // vector that we are inserting into will be operand 0 of the shuffle, so
+ // those elements are just 'i'. The inserted subvector is in the first
+ // positions of operand 1 of the shuffle. Example:
+ // insert v4i32 V, (v2i16 X), 2 --> shuffle v8i16 V', X', {0,1,2,3,8,9,6,7}
+ SmallVector<int, 16> Mask(NumMaskVals);
+ for (unsigned i = 0; i != NumMaskVals; ++i) {
+ if (i / NumSrcElts == InsIndex)
+ Mask[i] = (i % NumSrcElts) + NumMaskVals;
+ else
+ Mask[i] = i;
+ }
+
+ // Bail out if the target cannot handle the shuffle we want to create.
+ EVT SubVecEltVT = SubVecVT.getVectorElementType();
+ EVT ShufVT = EVT::getVectorVT(*DAG.getContext(), SubVecEltVT, NumMaskVals);
+ if (!TLI.isShuffleMaskLegal(Mask, ShufVT))
+ return SDValue();
+
+ // Step 2: Create a wide vector from the inserted source vector by appending
+ // undefined elements. This is the same size as our destination vector.
+ SDLoc DL(N);
+ SmallVector<SDValue, 8> ConcatOps(ExtendRatio, DAG.getUNDEF(SubVecVT));
+ ConcatOps[0] = SubVec;
+ SDValue PaddedSubV = DAG.getNode(ISD::CONCAT_VECTORS, DL, ShufVT, ConcatOps);
+
+ // Step 3: Shuffle in the padded subvector.
+ SDValue DestVecBC = DAG.getBitcast(ShufVT, DestVec);
+ SDValue Shuf = DAG.getVectorShuffle(ShufVT, DL, DestVecBC, PaddedSubV, Mask);
+ AddToWorklist(PaddedSubV.getNode());
+ AddToWorklist(DestVecBC.getNode());
+ AddToWorklist(Shuf.getNode());
+ return DAG.getBitcast(VT, Shuf);
+}
+
SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
SDValue InVec = N->getOperand(0);
SDValue InVal = N->getOperand(1);
SDValue EltNo = N->getOperand(2);
- SDLoc dl(N);
+ SDLoc DL(N);
// If the inserted element is an UNDEF, just use the input vector.
- if (InVal.getOpcode() == ISD::UNDEF)
+ if (InVal.isUndef())
return InVec;
EVT VT = InVec.getValueType();
+ unsigned NumElts = VT.getVectorNumElements();
- // If we can't generate a legal BUILD_VECTOR, exit
- if (LegalOperations && !TLI.isOperationLegal(ISD::BUILD_VECTOR, VT))
- return SDValue();
+ // Remove redundant insertions:
+ // (insert_vector_elt x (extract_vector_elt x idx) idx) -> x
+ if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+ InVec == InVal.getOperand(0) && EltNo == InVal.getOperand(1))
+ return InVec;
- // Check that we know which element is being inserted
- if (!isa<ConstantSDNode>(EltNo))
+ auto *IndexC = dyn_cast<ConstantSDNode>(EltNo);
+ if (!IndexC) {
+ // If this is a variable insert into an undef vector, it might be better
+ // to splat:
+ // inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... >
+ if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT)) {
+ SmallVector<SDValue, 8> Ops(NumElts, InVal);
+ return DAG.getBuildVector(VT, DL, Ops);
+ }
return SDValue();
- unsigned Elt = cast<ConstantSDNode>(EltNo)->getZExtValue();
+ }
+
+ // We must know which element is being inserted for the folds below.
+ unsigned Elt = IndexC->getZExtValue();
+ if (SDValue Shuf = combineInsertEltToShuffle(N, Elt))
+ return Shuf;
// Canonicalize insert_vector_elt dag nodes.
// Example:
@@ -12040,11 +15539,10 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
// do this only if indices are both constants and Idx1 < Idx0.
if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT && InVec.hasOneUse()
&& isa<ConstantSDNode>(InVec.getOperand(2))) {
- unsigned OtherElt =
- cast<ConstantSDNode>(InVec.getOperand(2))->getZExtValue();
+ unsigned OtherElt = InVec.getConstantOperandVal(2);
if (Elt < OtherElt) {
// Swap nodes.
- SDValue NewOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(N), VT,
+ SDValue NewOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT,
InVec.getOperand(0), InVal, EltNo);
AddToWorklist(NewOp.getNode());
return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(InVec.getNode()),
@@ -12052,6 +15550,10 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
}
}
+ // If we can't generate a legal BUILD_VECTOR, exit
+ if (LegalOperations && !TLI.isOperationLegal(ISD::BUILD_VECTOR, VT))
+ return SDValue();
+
// Check that the operand is a BUILD_VECTOR (or UNDEF, which can essentially
// be converted to a BUILD_VECTOR). Fill in the Ops vector with the
// vector elements.
@@ -12061,31 +15563,30 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
if (InVec.getOpcode() == ISD::BUILD_VECTOR && InVec.hasOneUse()) {
Ops.append(InVec.getNode()->op_begin(),
InVec.getNode()->op_end());
- } else if (InVec.getOpcode() == ISD::UNDEF) {
- unsigned NElts = VT.getVectorNumElements();
- Ops.append(NElts, DAG.getUNDEF(InVal.getValueType()));
+ } else if (InVec.isUndef()) {
+ Ops.append(NumElts, DAG.getUNDEF(InVal.getValueType()));
} else {
return SDValue();
}
+ assert(Ops.size() == NumElts && "Unexpected vector size");
// Insert the element
if (Elt < Ops.size()) {
// All the operands of BUILD_VECTOR must have the same type;
// we enforce that here.
EVT OpVT = Ops[0].getValueType();
- if (InVal.getValueType() != OpVT)
- InVal = OpVT.bitsGT(InVal.getValueType()) ?
- DAG.getNode(ISD::ANY_EXTEND, dl, OpVT, InVal) :
- DAG.getNode(ISD::TRUNCATE, dl, OpVT, InVal);
- Ops[Elt] = InVal;
+ Ops[Elt] = OpVT.isInteger() ? DAG.getAnyExtOrTrunc(InVal, DL, OpVT) : InVal;
}
// Return the new vector
- return DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Ops);
+ return DAG.getBuildVector(VT, DL, Ops);
}
-SDValue DAGCombiner::ReplaceExtractVectorEltOfLoadWithNarrowedLoad(
- SDNode *EVE, EVT InVecVT, SDValue EltNo, LoadSDNode *OriginalLoad) {
+SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
+ SDValue EltNo,
+ LoadSDNode *OriginalLoad) {
+ assert(!OriginalLoad->isVolatile());
+
EVT ResultVT = EVE->getValueType(0);
EVT VecEltVT = InVecVT.getVectorElementType();
unsigned Align = OriginalLoad->getAlignment();
@@ -12095,6 +15596,11 @@ SDValue DAGCombiner::ReplaceExtractVectorEltOfLoadWithNarrowedLoad(
if (NewAlign > Align || !TLI.isOperationLegalOrCustom(ISD::LOAD, VecEltVT))
return SDValue();
+ ISD::LoadExtType ExtTy = ResultVT.bitsGT(VecEltVT) ?
+ ISD::NON_EXTLOAD : ISD::EXTLOAD;
+ if (!TLI.shouldReduceLoadWidth(OriginalLoad, ExtTy, VecEltVT))
+ return SDValue();
+
Align = NewAlign;
SDValue NewPtr = OriginalLoad->getBasePtr();
@@ -12131,21 +15637,20 @@ SDValue DAGCombiner::ReplaceExtractVectorEltOfLoadWithNarrowedLoad(
VecEltVT)
? ISD::ZEXTLOAD
: ISD::EXTLOAD;
- Load = DAG.getExtLoad(
- ExtType, SDLoc(EVE), ResultVT, OriginalLoad->getChain(), NewPtr, MPI,
- VecEltVT, OriginalLoad->isVolatile(), OriginalLoad->isNonTemporal(),
- OriginalLoad->isInvariant(), Align, OriginalLoad->getAAInfo());
+ Load = DAG.getExtLoad(ExtType, SDLoc(EVE), ResultVT,
+ OriginalLoad->getChain(), NewPtr, MPI, VecEltVT,
+ Align, OriginalLoad->getMemOperand()->getFlags(),
+ OriginalLoad->getAAInfo());
Chain = Load.getValue(1);
} else {
- Load = DAG.getLoad(
- VecEltVT, SDLoc(EVE), OriginalLoad->getChain(), NewPtr, MPI,
- OriginalLoad->isVolatile(), OriginalLoad->isNonTemporal(),
- OriginalLoad->isInvariant(), Align, OriginalLoad->getAAInfo());
+ Load = DAG.getLoad(VecEltVT, SDLoc(EVE), OriginalLoad->getChain(), NewPtr,
+ MPI, Align, OriginalLoad->getMemOperand()->getFlags(),
+ OriginalLoad->getAAInfo());
Chain = Load.getValue(1);
if (ResultVT.bitsLT(VecEltVT))
Load = DAG.getNode(ISD::TRUNCATE, SDLoc(EVE), ResultVT, Load);
else
- Load = DAG.getNode(ISD::BITCAST, SDLoc(EVE), ResultVT, Load);
+ Load = DAG.getBitcast(ResultVT, Load);
}
WorklistRemover DeadNodes(*this);
SDValue From[] = { SDValue(EVE, 0), SDValue(OriginalLoad, 1) };
@@ -12161,74 +15666,162 @@ SDValue DAGCombiner::ReplaceExtractVectorEltOfLoadWithNarrowedLoad(
return SDValue(EVE, 0);
}
+/// Transform a vector binary operation into a scalar binary operation by moving
+/// the math/logic after an extract element of a vector.
+static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
+ bool LegalOperations) {
+ SDValue Vec = ExtElt->getOperand(0);
+ SDValue Index = ExtElt->getOperand(1);
+ auto *IndexC = dyn_cast<ConstantSDNode>(Index);
+ if (!IndexC || !ISD::isBinaryOp(Vec.getNode()) || !Vec.hasOneUse())
+ return SDValue();
+
+ // Targets may want to avoid this to prevent an expensive register transfer.
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ if (!TLI.shouldScalarizeBinop(Vec))
+ return SDValue();
+
+ // Extracting an element of a vector constant is constant-folded, so this
+ // transform is just replacing a vector op with a scalar op while moving the
+ // extract.
+ SDValue Op0 = Vec.getOperand(0);
+ SDValue Op1 = Vec.getOperand(1);
+ if (isAnyConstantBuildVector(Op0, true) ||
+ isAnyConstantBuildVector(Op1, true)) {
+ // extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C'
+ // extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC)
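+ // For example, if the target opts in via shouldScalarizeBinop:
+ //   extractelt (add X, <1,2,3,4>), 2 --> add (extractelt X, 2), 3
+ // where the extract of the constant operand has been constant-folded.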
+ SDLoc DL(ExtElt);
+ EVT VT = ExtElt->getValueType(0);
+ SDValue Ext0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Index);
+ SDValue Ext1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op1, Index);
+ return DAG.getNode(Vec.getOpcode(), DL, VT, Ext0, Ext1);
+ }
+
+ return SDValue();
+}
+
SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
- // (vextract (scalar_to_vector val, 0) -> val
- SDValue InVec = N->getOperand(0);
- EVT VT = InVec.getValueType();
- EVT NVT = N->getValueType(0);
+ SDValue VecOp = N->getOperand(0);
+ SDValue Index = N->getOperand(1);
+ EVT ScalarVT = N->getValueType(0);
+ EVT VecVT = VecOp.getValueType();
+ if (VecOp.isUndef())
+ return DAG.getUNDEF(ScalarVT);
+
+ // (extract_vector_elt (insert_vector_elt vec, val, idx), idx) -> val
+ //
+ // This only really matters if the index is non-constant since other combines
+ // on the constant elements already work.
+ SDLoc DL(N);
+ if (VecOp.getOpcode() == ISD::INSERT_VECTOR_ELT &&
+ Index == VecOp.getOperand(2)) {
+ SDValue Elt = VecOp.getOperand(1);
+ return VecVT.isInteger() ? DAG.getAnyExtOrTrunc(Elt, DL, ScalarVT) : Elt;
+ }
- if (InVec.getOpcode() == ISD::SCALAR_TO_VECTOR) {
+ // (vextract (scalar_to_vector val), 0) -> val
+ if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR) {
// Check if the result type doesn't match the inserted element type. A
// SCALAR_TO_VECTOR may truncate the inserted element and the
// EXTRACT_VECTOR_ELT may widen the extracted vector.
- SDValue InOp = InVec.getOperand(0);
- if (InOp.getValueType() != NVT) {
- assert(InOp.getValueType().isInteger() && NVT.isInteger());
- return DAG.getSExtOrTrunc(InOp, SDLoc(InVec), NVT);
+ SDValue InOp = VecOp.getOperand(0);
+ if (InOp.getValueType() != ScalarVT) {
+ assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
+ return DAG.getSExtOrTrunc(InOp, DL, ScalarVT);
}
return InOp;
}
- SDValue EltNo = N->getOperand(1);
- ConstantSDNode *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo);
+ // extract_vector_elt of out-of-bounds element -> UNDEF
+ auto *IndexC = dyn_cast<ConstantSDNode>(Index);
+ unsigned NumElts = VecVT.getVectorNumElements();
+ if (IndexC && IndexC->getAPIntValue().uge(NumElts))
+ return DAG.getUNDEF(ScalarVT);
// extract_vector_elt (build_vector x, y), 1 -> y
- if (ConstEltNo &&
- InVec.getOpcode() == ISD::BUILD_VECTOR &&
- TLI.isTypeLegal(VT) &&
- (InVec.hasOneUse() ||
- TLI.aggressivelyPreferBuildVectorSources(VT))) {
- SDValue Elt = InVec.getOperand(ConstEltNo->getZExtValue());
+ if (IndexC && VecOp.getOpcode() == ISD::BUILD_VECTOR &&
+ TLI.isTypeLegal(VecVT) &&
+ (VecOp.hasOneUse() || TLI.aggressivelyPreferBuildVectorSources(VecVT))) {
+ SDValue Elt = VecOp.getOperand(IndexC->getZExtValue());
EVT InEltVT = Elt.getValueType();
// Sometimes build_vector's scalar input types do not match result type.
- if (NVT == InEltVT)
+ if (ScalarVT == InEltVT)
return Elt;
    // TODO: It may be useful to truncate (when free) if the build_vector
    // implicitly converts.
}
+ // TODO: These transforms should not require the 'hasOneUse' restriction, but
+ // there are regressions on multiple targets without it. We can end up with a
+ // mess of scalar and vector code if we reduce only part of the DAG to scalar.
+ if (IndexC && VecOp.getOpcode() == ISD::BITCAST && VecVT.isInteger() &&
+ VecOp.hasOneUse()) {
+ // The vector index of the LSBs of the source depends on the endianness.
+ bool IsLE = DAG.getDataLayout().isLittleEndian();
+ unsigned ExtractIndex = IndexC->getZExtValue();
+ // extract_elt (v2i32 (bitcast i64:x)), BCTruncElt -> i32 (trunc i64:x)
+ unsigned BCTruncElt = IsLE ? 0 : NumElts - 1;
+ SDValue BCSrc = VecOp.getOperand(0);
+ if (ExtractIndex == BCTruncElt && BCSrc.getValueType().isScalarInteger())
+ return DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, BCSrc);
+
+ if (LegalTypes && BCSrc.getValueType().isInteger() &&
+ BCSrc.getOpcode() == ISD::SCALAR_TO_VECTOR) {
+ // ext_elt (bitcast (scalar_to_vec i64 X to v2i64) to v4i32), TruncElt -->
+ // trunc i64 X to i32
+ SDValue X = BCSrc.getOperand(0);
+ assert(X.getValueType().isScalarInteger() && ScalarVT.isScalarInteger() &&
+ "Extract element and scalar to vector can't change element type "
+ "from FP to integer.");
+ unsigned XBitWidth = X.getValueSizeInBits();
+ unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
+ BCTruncElt = IsLE ? 0 : XBitWidth / VecEltBitWidth - 1;
+
+ // An extract element return value type can be wider than its vector
+ // operand element type. In that case, the high bits are undefined, so
+ // it's possible that we may need to extend rather than truncate.
+ if (ExtractIndex == BCTruncElt && XBitWidth > VecEltBitWidth) {
+ assert(XBitWidth % VecEltBitWidth == 0 &&
+ "Scalar bitwidth must be a multiple of vector element bitwidth");
+ return DAG.getAnyExtOrTrunc(X, DL, ScalarVT);
+ }
+ }
+ }
+
+ if (SDValue BO = scalarizeExtractedBinop(N, DAG, LegalOperations))
+ return BO;
+
// Transform: (EXTRACT_VECTOR_ELT( VECTOR_SHUFFLE )) -> EXTRACT_VECTOR_ELT.
// We only perform this optimization before the op legalization phase because
// we may introduce new vector instructions which are not backed by TD
  // patterns; for example, on AVX, extracting elements from a wide vector
  // without using extract_subvector. However, if we can find an underlying
// scalar value, then we can always use that.
- if (ConstEltNo && InVec.getOpcode() == ISD::VECTOR_SHUFFLE) {
- int NumElem = VT.getVectorNumElements();
- ShuffleVectorSDNode *SVOp = cast<ShuffleVectorSDNode>(InVec);
+ if (IndexC && VecOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
+ auto *Shuf = cast<ShuffleVectorSDNode>(VecOp);
// Find the new index to extract from.
- int OrigElt = SVOp->getMaskElt(ConstEltNo->getZExtValue());
+ int OrigElt = Shuf->getMaskElt(IndexC->getZExtValue());
// Extracting an undef index is undef.
if (OrigElt == -1)
- return DAG.getUNDEF(NVT);
+ return DAG.getUNDEF(ScalarVT);
// Select the right vector half to extract from.
SDValue SVInVec;
- if (OrigElt < NumElem) {
- SVInVec = InVec->getOperand(0);
+ if (OrigElt < (int)NumElts) {
+ SVInVec = VecOp.getOperand(0);
} else {
- SVInVec = InVec->getOperand(1);
- OrigElt -= NumElem;
+ SVInVec = VecOp.getOperand(1);
+ OrigElt -= NumElts;
}
if (SVInVec.getOpcode() == ISD::BUILD_VECTOR) {
SDValue InOp = SVInVec.getOperand(OrigElt);
- if (InOp.getValueType() != NVT) {
- assert(InOp.getValueType().isInteger() && NVT.isInteger());
- InOp = DAG.getSExtOrTrunc(InOp, SDLoc(SVInVec), NVT);
+ if (InOp.getValueType() != ScalarVT) {
+ assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
+ InOp = DAG.getSExtOrTrunc(InOp, DL, ScalarVT);
}
return InOp;
@@ -12237,115 +15830,133 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
// FIXME: We should handle recursing on other vector shuffles and
// scalar_to_vector here as well.
- if (!LegalOperations) {
+ if (!LegalOperations ||
+ // FIXME: Should really be just isOperationLegalOrCustom.
+ TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecVT) ||
+ TLI.isOperationExpand(ISD::VECTOR_SHUFFLE, VecVT)) {
EVT IndexTy = TLI.getVectorIdxTy(DAG.getDataLayout());
- return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(N), NVT, SVInVec,
- DAG.getConstant(OrigElt, SDLoc(SVOp), IndexTy));
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, SVInVec,
+ DAG.getConstant(OrigElt, DL, IndexTy));
}
}
- bool BCNumEltsChanged = false;
- EVT ExtVT = VT.getVectorElementType();
- EVT LVT = ExtVT;
+ // If only EXTRACT_VECTOR_ELT nodes use the source vector we can
+ // simplify it based on the (valid) extraction indices.
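+ // For example, if lanes 0 and 2 of a v4i32 source are the only ones ever
+ // extracted, the demanded-elts analysis may simplify the other lanes (e.g.
+ // to undef), which can unlock further folds on the source vector.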
+ if (llvm::all_of(VecOp->uses(), [&](SDNode *Use) {
+ return Use->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+ Use->getOperand(0) == VecOp &&
+ isa<ConstantSDNode>(Use->getOperand(1));
+ })) {
+ APInt DemandedElts = APInt::getNullValue(NumElts);
+ for (SDNode *Use : VecOp->uses()) {
+ auto *CstElt = cast<ConstantSDNode>(Use->getOperand(1));
+ if (CstElt->getAPIntValue().ult(NumElts))
+ DemandedElts.setBit(CstElt->getZExtValue());
+ }
+ if (SimplifyDemandedVectorElts(VecOp, DemandedElts, true)) {
+ // We simplified the vector operand of this extract element. If this
+ // extract is not dead, visit it again so it is folded properly.
+ if (N->getOpcode() != ISD::DELETED_NODE)
+ AddToWorklist(N);
+ return SDValue(N, 0);
+ }
+ }
+ // Everything under here is trying to match an extract of a loaded value.
  // If the result of the load has to be truncated, then it's not necessarily
  // profitable.
- if (NVT.bitsLT(LVT) && !TLI.isTruncateFree(LVT, NVT))
+ bool BCNumEltsChanged = false;
+ EVT ExtVT = VecVT.getVectorElementType();
+ EVT LVT = ExtVT;
+ if (ScalarVT.bitsLT(LVT) && !TLI.isTruncateFree(LVT, ScalarVT))
return SDValue();
- if (InVec.getOpcode() == ISD::BITCAST) {
+ if (VecOp.getOpcode() == ISD::BITCAST) {
// Don't duplicate a load with other uses.
- if (!InVec.hasOneUse())
+ if (!VecOp.hasOneUse())
return SDValue();
- EVT BCVT = InVec.getOperand(0).getValueType();
+ EVT BCVT = VecOp.getOperand(0).getValueType();
if (!BCVT.isVector() || ExtVT.bitsGT(BCVT.getVectorElementType()))
return SDValue();
- if (VT.getVectorNumElements() != BCVT.getVectorNumElements())
+ if (NumElts != BCVT.getVectorNumElements())
BCNumEltsChanged = true;
- InVec = InVec.getOperand(0);
+ VecOp = VecOp.getOperand(0);
ExtVT = BCVT.getVectorElementType();
}
- // (vextract (vN[if]M load $addr), i) -> ([if]M load $addr + i * size)
- if (!LegalOperations && !ConstEltNo && InVec.hasOneUse() &&
- ISD::isNormalLoad(InVec.getNode()) &&
- !N->getOperand(1)->hasPredecessor(InVec.getNode())) {
- SDValue Index = N->getOperand(1);
- if (LoadSDNode *OrigLoad = dyn_cast<LoadSDNode>(InVec))
- return ReplaceExtractVectorEltOfLoadWithNarrowedLoad(N, VT, Index,
- OrigLoad);
+ // extract (vector load $addr), i --> load $addr + i * size
+ if (!LegalOperations && !IndexC && VecOp.hasOneUse() &&
+ ISD::isNormalLoad(VecOp.getNode()) &&
+ !Index->hasPredecessor(VecOp.getNode())) {
+ auto *VecLoad = dyn_cast<LoadSDNode>(VecOp);
+ if (VecLoad && !VecLoad->isVolatile())
+ return scalarizeExtractedVectorLoad(N, VecVT, Index, VecLoad);
}
// Perform only after legalization to ensure build_vector / vector_shuffle
// optimizations have already been done.
- if (!LegalOperations) return SDValue();
+ if (!LegalOperations || !IndexC)
+ return SDValue();
// (vextract (v4f32 load $addr), c) -> (f32 load $addr+c*size)
// (vextract (v4f32 s2v (f32 load $addr)), c) -> (f32 load $addr+c*size)
// (vextract (v4f32 shuffle (load $addr), <1,u,u,u>), 0) -> (f32 load $addr)
+ int Elt = IndexC->getZExtValue();
+ LoadSDNode *LN0 = nullptr;
+ if (ISD::isNormalLoad(VecOp.getNode())) {
+ LN0 = cast<LoadSDNode>(VecOp);
+ } else if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR &&
+ VecOp.getOperand(0).getValueType() == ExtVT &&
+ ISD::isNormalLoad(VecOp.getOperand(0).getNode())) {
+ // Don't duplicate a load with other uses.
+ if (!VecOp.hasOneUse())
+ return SDValue();
- if (ConstEltNo) {
- int Elt = cast<ConstantSDNode>(EltNo)->getZExtValue();
+ LN0 = cast<LoadSDNode>(VecOp.getOperand(0));
+ }
+ if (auto *Shuf = dyn_cast<ShuffleVectorSDNode>(VecOp)) {
+ // (vextract (vector_shuffle (load $addr), v2, <1, u, u, u>), 1)
+ // =>
+ // (load $addr+1*size)
- LoadSDNode *LN0 = nullptr;
- const ShuffleVectorSDNode *SVN = nullptr;
- if (ISD::isNormalLoad(InVec.getNode())) {
- LN0 = cast<LoadSDNode>(InVec);
- } else if (InVec.getOpcode() == ISD::SCALAR_TO_VECTOR &&
- InVec.getOperand(0).getValueType() == ExtVT &&
- ISD::isNormalLoad(InVec.getOperand(0).getNode())) {
- // Don't duplicate a load with other uses.
- if (!InVec.hasOneUse())
- return SDValue();
+ // Don't duplicate a load with other uses.
+ if (!VecOp.hasOneUse())
+ return SDValue();
+
+ // If the bit convert changed the number of elements, it is unsafe
+ // to examine the mask.
+ if (BCNumEltsChanged)
+ return SDValue();
- LN0 = cast<LoadSDNode>(InVec.getOperand(0));
- } else if ((SVN = dyn_cast<ShuffleVectorSDNode>(InVec))) {
- // (vextract (vector_shuffle (load $addr), v2, <1, u, u, u>), 1)
- // =>
- // (load $addr+1*size)
+ // Select the input vector, guarding against an out-of-range extract index.
+ int Idx = (Elt > (int)NumElts) ? -1 : Shuf->getMaskElt(Elt);
+ VecOp = (Idx < (int)NumElts) ? VecOp.getOperand(0) : VecOp.getOperand(1);
+ if (VecOp.getOpcode() == ISD::BITCAST) {
// Don't duplicate a load with other uses.
- if (!InVec.hasOneUse())
- return SDValue();
-
- // If the bit convert changed the number of elements, it is unsafe
- // to examine the mask.
- if (BCNumEltsChanged)
+ if (!VecOp.hasOneUse())
return SDValue();
- // Select the input vector, guarding against out of range extract vector.
- unsigned NumElems = VT.getVectorNumElements();
- int Idx = (Elt > (int)NumElems) ? -1 : SVN->getMaskElt(Elt);
- InVec = (Idx < (int)NumElems) ? InVec.getOperand(0) : InVec.getOperand(1);
-
- if (InVec.getOpcode() == ISD::BITCAST) {
- // Don't duplicate a load with other uses.
- if (!InVec.hasOneUse())
- return SDValue();
-
- InVec = InVec.getOperand(0);
- }
- if (ISD::isNormalLoad(InVec.getNode())) {
- LN0 = cast<LoadSDNode>(InVec);
- Elt = (Idx < (int)NumElems) ? Idx : Idx - (int)NumElems;
- EltNo = DAG.getConstant(Elt, SDLoc(EltNo), EltNo.getValueType());
- }
+ VecOp = VecOp.getOperand(0);
}
+ if (ISD::isNormalLoad(VecOp.getNode())) {
+ LN0 = cast<LoadSDNode>(VecOp);
+ Elt = (Idx < (int)NumElts) ? Idx : Idx - (int)NumElts;
+ Index = DAG.getConstant(Elt, DL, Index.getValueType());
+ }
+ }
- // Make sure we found a non-volatile load and the extractelement is
- // the only use.
- if (!LN0 || !LN0->hasNUsesOfValue(1,0) || LN0->isVolatile())
- return SDValue();
-
- // If Idx was -1 above, Elt is going to be -1, so just return undef.
- if (Elt == -1)
- return DAG.getUNDEF(LVT);
+ // Make sure we found a non-volatile load and the extractelement is
+ // the only use.
+ if (!LN0 || !LN0->hasNUsesOfValue(1,0) || LN0->isVolatile())
+ return SDValue();
- return ReplaceExtractVectorEltOfLoadWithNarrowedLoad(N, VT, EltNo, LN0);
- }
+ // If Idx was -1 above, Elt is going to be -1, so just return undef.
+ if (Elt == -1)
+ return DAG.getUNDEF(LVT);
- return SDValue();
+ return scalarizeExtractedVectorLoad(N, VecVT, Index, LN0);
}
// Simplify (build_vec (ext )) to (bitcast (build_vec ))
@@ -12360,7 +15971,7 @@ SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) {
return SDValue();
unsigned NumInScalars = N->getNumOperands();
- SDLoc dl(N);
+ SDLoc DL(N);
EVT VT = N->getValueType(0);
// Check to see if this is a BUILD_VECTOR of a bunch of values
@@ -12374,7 +15985,7 @@ SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) {
for (unsigned i = 0; i != NumInScalars; ++i) {
SDValue In = N->getOperand(i);
// Ignore undef inputs.
- if (In.getOpcode() == ISD::UNDEF) continue;
+ if (In.isUndef()) continue;
bool AnyExt = In.getOpcode() == ISD::ANY_EXTEND;
bool ZeroExt = In.getOpcode() == ISD::ZERO_EXTEND;
@@ -12419,7 +16030,7 @@ SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) {
unsigned ElemRatio = OutScalarTy.getSizeInBits()/SourceType.getSizeInBits();
assert(ElemRatio > 1 && "Invalid element size ratio");
SDValue Filler = AllAnyExt ? DAG.getUNDEF(SourceType):
- DAG.getConstant(0, SDLoc(N), SourceType);
+ DAG.getConstant(0, DL, SourceType);
unsigned NewBVElems = ElemRatio * VT.getVectorNumElements();
SmallVector<SDValue, 8> Ops(NewBVElems, Filler);
@@ -12429,9 +16040,9 @@ SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) {
SDValue Cast = N->getOperand(i);
assert((Cast.getOpcode() == ISD::ANY_EXTEND ||
Cast.getOpcode() == ISD::ZERO_EXTEND ||
- Cast.getOpcode() == ISD::UNDEF) && "Invalid cast opcode");
+ Cast.isUndef()) && "Invalid cast opcode");
SDValue In;
- if (Cast.getOpcode() == ISD::UNDEF)
+ if (Cast.isUndef())
In = DAG.getUNDEF(SourceType);
else
In = Cast->getOperand(0);
@@ -12447,259 +16058,550 @@ SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) {
assert(VecVT.getSizeInBits() == VT.getSizeInBits() &&
"Invalid vector size");
// Check if the new vector type is legal.
- if (!isTypeLegal(VecVT)) return SDValue();
+ if (!isTypeLegal(VecVT) ||
+ (!TLI.isOperationLegal(ISD::BUILD_VECTOR, VecVT) &&
+ TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)))
+ return SDValue();
// Make the new BUILD_VECTOR.
- SDValue BV = DAG.getNode(ISD::BUILD_VECTOR, dl, VecVT, Ops);
+ SDValue BV = DAG.getBuildVector(VecVT, DL, Ops);
// The new BUILD_VECTOR node has the potential to be further optimized.
AddToWorklist(BV.getNode());
// Bitcast to the desired type.
- return DAG.getNode(ISD::BITCAST, dl, VT, BV);
+ return DAG.getBitcast(VT, BV);
}
-SDValue DAGCombiner::reduceBuildVecConvertToConvertBuildVec(SDNode *N) {
+SDValue DAGCombiner::createBuildVecShuffle(const SDLoc &DL, SDNode *N,
+ ArrayRef<int> VectorMask,
+ SDValue VecIn1, SDValue VecIn2,
+ unsigned LeftIdx) {
+ MVT IdxTy = TLI.getVectorIdxTy(DAG.getDataLayout());
+ SDValue ZeroIdx = DAG.getConstant(0, DL, IdxTy);
+
EVT VT = N->getValueType(0);
+ EVT InVT1 = VecIn1.getValueType();
+ EVT InVT2 = VecIn2.getNode() ? VecIn2.getValueType() : InVT1;
+
+ unsigned Vec2Offset = 0;
+ unsigned NumElems = VT.getVectorNumElements();
+ unsigned ShuffleNumElems = NumElems;
+
+ // If both input vectors are extracted from the same base vector, we do
+ // not need the extra addend (Vec2Offset) when computing the shuffle mask.
+ if (!VecIn2 || !(VecIn1.getOpcode() == ISD::EXTRACT_SUBVECTOR) ||
+ !(VecIn2.getOpcode() == ISD::EXTRACT_SUBVECTOR) ||
+ !(VecIn1.getOperand(0) == VecIn2.getOperand(0)))
+ Vec2Offset = InVT1.getVectorNumElements();
+
+ // We can't generate a shuffle node with mismatched input and output types.
+ // Try to make the types match the type of the output.
+ if (InVT1 != VT || InVT2 != VT) {
+ if ((VT.getSizeInBits() % InVT1.getSizeInBits() == 0) && InVT1 == InVT2) {
+ // If the output vector length is a multiple of both input lengths,
+ // we can concatenate them and pad the rest with undefs.
+ unsigned NumConcats = VT.getSizeInBits() / InVT1.getSizeInBits();
+ assert(NumConcats >= 2 && "Concat needs at least two inputs!");
+ SmallVector<SDValue, 2> ConcatOps(NumConcats, DAG.getUNDEF(InVT1));
+ ConcatOps[0] = VecIn1;
+ ConcatOps[1] = VecIn2 ? VecIn2 : DAG.getUNDEF(InVT1);
+ VecIn1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
+ VecIn2 = SDValue();
+ } else if (InVT1.getSizeInBits() == VT.getSizeInBits() * 2) {
+ if (!TLI.isExtractSubvectorCheap(VT, InVT1, NumElems))
+ return SDValue();
- unsigned NumInScalars = N->getNumOperands();
- SDLoc dl(N);
+ if (!VecIn2.getNode()) {
+ // If we only have one input vector, and it's twice the size of the
+ // output, split it in two.
+ VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1,
+ DAG.getConstant(NumElems, DL, IdxTy));
+ VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1, ZeroIdx);
+ // Since we now have shorter input vectors, adjust the offset of the
+ // second vector's start.
+ Vec2Offset = NumElems;
+ } else if (InVT2.getSizeInBits() <= InVT1.getSizeInBits()) {
+ // VecIn1 is wider than the output, and we have another, possibly
+ // smaller input. Pad the smaller input with undefs, shuffle at the
+ // input vector width, and extract the output.
+ // The shuffle type is different than VT, so check legality again.
+ if (LegalOperations &&
+ !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, InVT1))
+ return SDValue();
- EVT SrcVT = MVT::Other;
- unsigned Opcode = ISD::DELETED_NODE;
- unsigned NumDefs = 0;
+ // Legalizing INSERT_SUBVECTOR is tricky - you basically have to
+ // lower it back into a BUILD_VECTOR. So if the inserted type is
+ // illegal, don't even try.
+ if (InVT1 != InVT2) {
+ if (!TLI.isTypeLegal(InVT2))
+ return SDValue();
+ VecIn2 = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, InVT1,
+ DAG.getUNDEF(InVT1), VecIn2, ZeroIdx);
+ }
+ ShuffleNumElems = NumElems * 2;
+ } else {
+ // Both VecIn1 and VecIn2 are wider than the output, and VecIn2 is wider
+ // than VecIn1. We can't handle this for now - this case will disappear
+ // when we start sorting the vectors by type.
+ return SDValue();
+ }
+ } else if (InVT2.getSizeInBits() * 2 == VT.getSizeInBits() &&
+ InVT1.getSizeInBits() == VT.getSizeInBits()) {
+ SmallVector<SDValue, 2> ConcatOps(2, DAG.getUNDEF(InVT2));
+ ConcatOps[0] = VecIn2;
+ VecIn2 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
+ } else {
+ // TODO: Support cases where the length mismatch isn't exactly by a
+ // factor of 2.
+ // TODO: Move this check upwards, so that if we have bad type
+ // mismatches, we don't create any DAG nodes.
+ return SDValue();
+ }
+ }
- for (unsigned i = 0; i != NumInScalars; ++i) {
- SDValue In = N->getOperand(i);
- unsigned Opc = In.getOpcode();
+ // Initialize mask to undef.
+ SmallVector<int, 8> Mask(ShuffleNumElems, -1);
- if (Opc == ISD::UNDEF)
+ // We only need to run up to the number of elements actually used, not the
+ // total number of elements in the shuffle; if we are shuffling a wider
+ // vector, the high lanes simply remain undef.
+ for (unsigned i = 0; i != NumElems; ++i) {
+ if (VectorMask[i] <= 0)
continue;
- // If all scalar values are floats and converted from integers.
- if (Opcode == ISD::DELETED_NODE &&
- (Opc == ISD::UINT_TO_FP || Opc == ISD::SINT_TO_FP)) {
- Opcode = Opc;
+ unsigned ExtIndex = N->getOperand(i).getConstantOperandVal(1);
+ if (VectorMask[i] == (int)LeftIdx) {
+ Mask[i] = ExtIndex;
+ } else if (VectorMask[i] == (int)LeftIdx + 1) {
+ Mask[i] = Vec2Offset + ExtIndex;
}
+ }
- if (Opc != Opcode)
- return SDValue();
+ // The type of the input vectors may have changed above.
+ InVT1 = VecIn1.getValueType();
- EVT InVT = In.getOperand(0).getValueType();
+ // If we already have a VecIn2, it should have the same type as VecIn1.
+ // If we don't, get an undef/zero vector of the appropriate type.
+ VecIn2 = VecIn2.getNode() ? VecIn2 : DAG.getUNDEF(InVT1);
+ assert(InVT1 == VecIn2.getValueType() && "Unexpected second input type.");
- // If all scalar values are typed differently, bail out. It's chosen to
- // simplify BUILD_VECTOR of integer types.
- if (SrcVT == MVT::Other)
- SrcVT = InVT;
- if (SrcVT != InVT)
- return SDValue();
- NumDefs++;
- }
+ SDValue Shuffle = DAG.getVectorShuffle(InVT1, DL, VecIn1, VecIn2, Mask);
+ if (ShuffleNumElems > NumElems)
+ Shuffle = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Shuffle, ZeroIdx);
- // If the vector has just one element defined, it's not worth to fold it into
- // a vectorized one.
- if (NumDefs < 2)
- return SDValue();
+ return Shuffle;
+}
- assert((Opcode == ISD::UINT_TO_FP || Opcode == ISD::SINT_TO_FP)
- && "Should only handle conversion from integer to float.");
- assert(SrcVT != MVT::Other && "Cannot determine source type!");
+static SDValue reduceBuildVecToShuffleWithZero(SDNode *BV, SelectionDAG &DAG) {
+ assert(BV->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
- EVT NVT = EVT::getVectorVT(*DAG.getContext(), SrcVT, NumInScalars);
+ // First, determine where the build vector is not undef.
+ // TODO: We could extend this to handle zero elements as well as undefs.
+ int NumBVOps = BV->getNumOperands();
+ int ZextElt = -1;
+ for (int i = 0; i != NumBVOps; ++i) {
+ SDValue Op = BV->getOperand(i);
+ if (Op.isUndef())
+ continue;
+ if (ZextElt == -1)
+ ZextElt = i;
+ else
+ return SDValue();
+ }
+ // Bail out if there's no non-undef element.
+ if (ZextElt == -1)
+ return SDValue();
- if (!TLI.isOperationLegalOrCustom(Opcode, NVT))
+ // The build vector contains some number of undef elements and exactly
+ // one other element. That other element must be a zero-extended scalar
+ // extracted from a vector at a constant index to turn this into a shuffle.
+ // Also, require that the build vector does not implicitly truncate/extend
+ // its elements.
+ // TODO: This could be enhanced to allow ANY_EXTEND as well as ZERO_EXTEND.
+ EVT VT = BV->getValueType(0);
+ SDValue Zext = BV->getOperand(ZextElt);
+ if (Zext.getOpcode() != ISD::ZERO_EXTEND || !Zext.hasOneUse() ||
+ Zext.getOperand(0).getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+ !isa<ConstantSDNode>(Zext.getOperand(0).getOperand(1)) ||
+ Zext.getValueSizeInBits() != VT.getScalarSizeInBits())
return SDValue();
- // Just because the floating-point vector type is legal does not necessarily
- // mean that the corresponding integer vector type is.
- if (!isTypeLegal(NVT))
+ // The zero-extend's destination size must be a multiple of the source
+ // size, and we must be building a vector of the same size as the source
+ // of the extract element.
+ SDValue Extract = Zext.getOperand(0);
+ unsigned DestSize = Zext.getValueSizeInBits();
+ unsigned SrcSize = Extract.getValueSizeInBits();
+ if (DestSize % SrcSize != 0 ||
+ Extract.getOperand(0).getValueSizeInBits() != VT.getSizeInBits())
return SDValue();
- SmallVector<SDValue, 8> Opnds;
- for (unsigned i = 0; i != NumInScalars; ++i) {
- SDValue In = N->getOperand(i);
+ // Create a shuffle mask that will combine the extracted element with zeros
+ // and undefs.
+ int ZextRatio = DestSize / SrcSize;
+ int NumMaskElts = NumBVOps * ZextRatio;
+ SmallVector<int, 32> ShufMask(NumMaskElts, -1);
+ for (int i = 0; i != NumMaskElts; ++i) {
+ if (i / ZextRatio == ZextElt) {
+ // The low bits of the (potentially translated) extracted element map to
+ // the source vector. The high bits map to zero. We will use a zero vector
+ // as the 2nd source operand of the shuffle, so use the 1st element of
+ // that vector (mask value is number-of-elements) for the high bits.
+ if (i % ZextRatio == 0)
+ ShufMask[i] = Extract.getConstantOperandVal(1);
+ else
+ ShufMask[i] = NumMaskElts;
+ }
- if (In.getOpcode() == ISD::UNDEF)
- Opnds.push_back(DAG.getUNDEF(SrcVT));
- else
- Opnds.push_back(In.getOperand(0));
+ // Undef elements of the build vector remain undef because we initialize
+ // the shuffle mask with -1.
}
- SDValue BV = DAG.getNode(ISD::BUILD_VECTOR, dl, NVT, Opnds);
- AddToWorklist(BV.getNode());
- return DAG.getNode(Opcode, dl, VT, BV);
+ // Turn this into a shuffle with zero if that's legal.
+ EVT VecVT = Extract.getOperand(0).getValueType();
+ if (!DAG.getTargetLoweringInfo().isShuffleMaskLegal(ShufMask, VecVT))
+ return SDValue();
+
+ // buildvec undef, ..., (zext (extractelt V, IndexC)), undef... -->
+ // bitcast (shuffle V, ZeroVec, VectorMask)
+ SDLoc DL(BV);
+ SDValue ZeroVec = DAG.getConstant(0, DL, VecVT);
+ SDValue Shuf = DAG.getVectorShuffle(VecVT, DL, Extract.getOperand(0), ZeroVec,
+ ShufMask);
+ return DAG.getBitcast(VT, Shuf);
}
-SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) {
- unsigned NumInScalars = N->getNumOperands();
- SDLoc dl(N);
+// Check to see if this is a BUILD_VECTOR of a bunch of EXTRACT_VECTOR_ELT
+// operations. If the types of the vectors we're extracting from allow it,
+// turn this into a vector_shuffle node.
+SDValue DAGCombiner::reduceBuildVecToShuffle(SDNode *N) {
+ SDLoc DL(N);
EVT VT = N->getValueType(0);
- // A vector built entirely of undefs is undef.
- if (ISD::allOperandsUndef(N))
- return DAG.getUNDEF(VT);
-
- if (SDValue V = reduceBuildVecExtToExtBuildVec(N))
- return V;
-
- if (SDValue V = reduceBuildVecConvertToConvertBuildVec(N))
- return V;
-
- // Check to see if this is a BUILD_VECTOR of a bunch of EXTRACT_VECTOR_ELT
- // operations. If so, and if the EXTRACT_VECTOR_ELT vector inputs come from
- // at most two distinct vectors, turn this into a shuffle node.
-
// Only type-legal BUILD_VECTOR nodes are converted to shuffle nodes.
if (!isTypeLegal(VT))
return SDValue();
+ if (SDValue V = reduceBuildVecToShuffleWithZero(N, DAG))
+ return V;
+
// May only combine to shuffle after legalize if shuffle is legal.
if (LegalOperations && !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, VT))
return SDValue();
- SDValue VecIn1, VecIn2;
bool UsesZeroVector = false;
- for (unsigned i = 0; i != NumInScalars; ++i) {
+ unsigned NumElems = N->getNumOperands();
+
+ // Record, for each element of the newly built vector, which input vector
+ // that element comes from. -1 stands for undef, 0 for the zero vector,
+ // and positive values for the input vectors.
+ // VectorMask maps each element to its vector number, and VecIn maps vector
+ // numbers to their initial SDValues.
+
+ SmallVector<int, 8> VectorMask(NumElems, -1);
+ SmallVector<SDValue, 8> VecIn;
+ VecIn.push_back(SDValue());
+
+ for (unsigned i = 0; i != NumElems; ++i) {
SDValue Op = N->getOperand(i);
- // Ignore undef inputs.
- if (Op.getOpcode() == ISD::UNDEF) continue;
- // See if we can combine this build_vector into a blend with a zero vector.
- if (!VecIn2.getNode() && (isNullConstant(Op) || isNullFPConstant(Op))) {
+ if (Op.isUndef())
+ continue;
+
+ // See if we can use a blend with a zero vector.
+ // TODO: Should we generalize this to a blend with an arbitrary constant
+ // vector?
+ if (isNullConstant(Op) || isNullFPConstant(Op)) {
UsesZeroVector = true;
+ VectorMask[i] = 0;
continue;
}
- // If this input is something other than a EXTRACT_VECTOR_ELT with a
- // constant index, bail out.
+ // Not an undef or zero. If the input is something other than an
+ // EXTRACT_VECTOR_ELT with an in-range constant index, bail out.
if (Op.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
- !isa<ConstantSDNode>(Op.getOperand(1))) {
- VecIn1 = VecIn2 = SDValue(nullptr, 0);
- break;
- }
-
- // We allow up to two distinct input vectors.
+ !isa<ConstantSDNode>(Op.getOperand(1)))
+ return SDValue();
SDValue ExtractedFromVec = Op.getOperand(0);
- if (ExtractedFromVec == VecIn1 || ExtractedFromVec == VecIn2)
- continue;
- if (!VecIn1.getNode()) {
- VecIn1 = ExtractedFromVec;
- } else if (!VecIn2.getNode() && !UsesZeroVector) {
- VecIn2 = ExtractedFromVec;
- } else {
- // Too many inputs.
- VecIn1 = VecIn2 = SDValue(nullptr, 0);
- break;
- }
+ APInt ExtractIdx = cast<ConstantSDNode>(Op.getOperand(1))->getAPIntValue();
+ if (ExtractIdx.uge(ExtractedFromVec.getValueType().getVectorNumElements()))
+ return SDValue();
+
+ // All inputs must have the same element type as the output.
+ if (VT.getVectorElementType() !=
+ ExtractedFromVec.getValueType().getVectorElementType())
+ return SDValue();
+
+ // Have we seen this input vector before?
+ // The list of source vectors is expected to be tiny (usually 1 or 2
+ // vectors), so using a map back from SDValues to numbers isn't worth it.
+ unsigned Idx = std::distance(
+ VecIn.begin(), std::find(VecIn.begin(), VecIn.end(), ExtractedFromVec));
+ if (Idx == VecIn.size())
+ VecIn.push_back(ExtractedFromVec);
+
+ VectorMask[i] = Idx;
}
- // If everything is good, we can make a shuffle operation.
- if (VecIn1.getNode()) {
- unsigned InNumElements = VecIn1.getValueType().getVectorNumElements();
- SmallVector<int, 8> Mask;
- for (unsigned i = 0; i != NumInScalars; ++i) {
- unsigned Opcode = N->getOperand(i).getOpcode();
- if (Opcode == ISD::UNDEF) {
- Mask.push_back(-1);
- continue;
- }
+ // If we didn't find at least one input vector, bail out.
+ if (VecIn.size() < 2)
+ return SDValue();
- // Operands can also be zero.
- if (Opcode != ISD::EXTRACT_VECTOR_ELT) {
- assert(UsesZeroVector &&
- (Opcode == ISD::Constant || Opcode == ISD::ConstantFP) &&
- "Unexpected node found!");
- Mask.push_back(NumInScalars+i);
+ // If all the operands of the BUILD_VECTOR extract from the same vector,
+ // then split that vector efficiently based on the maximum vector access
+ // index and adjust VectorMask and VecIn accordingly.
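+ // For example, if a v4i32 build vector extracts only lanes 0 and 12 of a
+ // v16i32 source, NearestPow2 is 16, so (provided v8i32 is legal) the
+ // source is split into two v8i32 halves and the mask entries are
+ // retargeted to whichever half each index falls in.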
+ if (VecIn.size() == 2) {
+ unsigned MaxIndex = 0;
+ unsigned NearestPow2 = 0;
+ SDValue Vec = VecIn.back();
+ EVT InVT = Vec.getValueType();
+ MVT IdxTy = TLI.getVectorIdxTy(DAG.getDataLayout());
+ SmallVector<unsigned, 8> IndexVec(NumElems, 0);
+
+ for (unsigned i = 0; i < NumElems; i++) {
+ if (VectorMask[i] <= 0)
continue;
+ unsigned Index = N->getOperand(i).getConstantOperandVal(1);
+ IndexVec[i] = Index;
+ MaxIndex = std::max(MaxIndex, Index);
+ }
+
+ NearestPow2 = PowerOf2Ceil(MaxIndex);
+ if (InVT.isSimple() && NearestPow2 > 2 && MaxIndex < NearestPow2 &&
+ NumElems * 2 < NearestPow2) {
+ unsigned SplitSize = NearestPow2 / 2;
+ EVT SplitVT = EVT::getVectorVT(*DAG.getContext(),
+ InVT.getVectorElementType(), SplitSize);
+ if (TLI.isTypeLegal(SplitVT)) {
+ SDValue VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec,
+ DAG.getConstant(SplitSize, DL, IdxTy));
+ SDValue VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec,
+ DAG.getConstant(0, DL, IdxTy));
+ VecIn.pop_back();
+ VecIn.push_back(VecIn1);
+ VecIn.push_back(VecIn2);
+
+ for (unsigned i = 0; i < NumElems; i++) {
+ if (VectorMask[i] <= 0)
+ continue;
+ VectorMask[i] = (IndexVec[i] < SplitSize) ? 1 : 2;
+ }
}
+ }
+ }
- // If extracting from the first vector, just use the index directly.
- SDValue Extract = N->getOperand(i);
- SDValue ExtVal = Extract.getOperand(1);
- unsigned ExtIndex = cast<ConstantSDNode>(ExtVal)->getZExtValue();
- if (Extract.getOperand(0) == VecIn1) {
- Mask.push_back(ExtIndex);
- continue;
+ // TODO: We want to sort the vectors by descending length, so that adjacent
+ // pairs have similar length, and the longer vector is always first in the
+ // pair.
+
+ // TODO: Should this fire if some of the input vectors have an illegal type
+ // (like it does now), or should we let legalization run its course first?
+
+ // Shuffle phase:
+ // Take pairs of vectors, and shuffle them so that the result has elements
+ // from these vectors in the correct places.
+ // For example, given:
+ // t10: i32 = extract_vector_elt t1, Constant:i64<0>
+ // t11: i32 = extract_vector_elt t2, Constant:i64<0>
+ // t12: i32 = extract_vector_elt t3, Constant:i64<0>
+ // t13: i32 = extract_vector_elt t1, Constant:i64<1>
+ // t14: v4i32 = BUILD_VECTOR t10, t11, t12, t13
+ // We will generate:
+ // t20: v4i32 = vector_shuffle<0,4,u,1> t1, t2
+ // t21: v4i32 = vector_shuffle<u,u,0,u> t3, undef
+ SmallVector<SDValue, 4> Shuffles;
+ for (unsigned In = 0, Len = (VecIn.size() / 2); In < Len; ++In) {
+ unsigned LeftIdx = 2 * In + 1;
+ SDValue VecLeft = VecIn[LeftIdx];
+ SDValue VecRight =
+ (LeftIdx + 1) < VecIn.size() ? VecIn[LeftIdx + 1] : SDValue();
+
+ if (SDValue Shuffle = createBuildVecShuffle(DL, N, VectorMask, VecLeft,
+ VecRight, LeftIdx))
+ Shuffles.push_back(Shuffle);
+ else
+ return SDValue();
+ }
+
+ // If we need the zero vector as an "ingredient" in the blend tree, add it
+ // to the list of shuffles.
+ if (UsesZeroVector)
+ Shuffles.push_back(VT.isInteger() ? DAG.getConstant(0, DL, VT)
+ : DAG.getConstantFP(0.0, DL, VT));
+
+ // If we only have one shuffle, we're done.
+ if (Shuffles.size() == 1)
+ return Shuffles[0];
+
+ // Update the vector mask to point to the post-shuffle vectors.
+ for (int &Vec : VectorMask)
+ if (Vec == 0)
+ Vec = Shuffles.size() - 1;
+ else
+ Vec = (Vec - 1) / 2;
+
+ // More than one shuffle. Generate a binary tree of blends, e.g. if from
+ // the previous step we got the set of shuffles t10, t11, t12, t13, we will
+ // generate:
+ // t10: v8i32 = vector_shuffle<0,8,u,u,u,u,u,u> t1, t2
+ // t11: v8i32 = vector_shuffle<u,u,0,8,u,u,u,u> t3, t4
+ // t12: v8i32 = vector_shuffle<u,u,u,u,0,8,u,u> t5, t6
+ // t13: v8i32 = vector_shuffle<u,u,u,u,u,u,0,8> t7, t8
+ // t20: v8i32 = vector_shuffle<0,1,10,11,u,u,u,u> t10, t11
+ // t21: v8i32 = vector_shuffle<u,u,u,u,4,5,14,15> t12, t13
+ // t30: v8i32 = vector_shuffle<0,1,2,3,12,13,14,15> t20, t21
+
+ // Make sure the initial size of the shuffle list is even.
+ if (Shuffles.size() % 2)
+ Shuffles.push_back(DAG.getUNDEF(VT));
+
+ for (unsigned CurSize = Shuffles.size(); CurSize > 1; CurSize /= 2) {
+ if (CurSize % 2) {
+ Shuffles[CurSize] = DAG.getUNDEF(VT);
+ CurSize++;
+ }
+ for (unsigned In = 0, Len = CurSize / 2; In < Len; ++In) {
+ int Left = 2 * In;
+ int Right = 2 * In + 1;
+ SmallVector<int, 8> Mask(NumElems, -1);
+ for (unsigned i = 0; i != NumElems; ++i) {
+ if (VectorMask[i] == Left) {
+ Mask[i] = i;
+ VectorMask[i] = In;
+ } else if (VectorMask[i] == Right) {
+ Mask[i] = i + NumElems;
+ VectorMask[i] = In;
+ }
}
- // Otherwise, use InIdx + InputVecSize
- Mask.push_back(InNumElements + ExtIndex);
+ Shuffles[In] =
+ DAG.getVectorShuffle(VT, DL, Shuffles[Left], Shuffles[Right], Mask);
}
+ }
+ return Shuffles[0];
+}
- // Avoid introducing illegal shuffles with zero.
- if (UsesZeroVector && !TLI.isVectorClearMaskLegal(Mask, VT))
+// Try to turn a build vector of zero extends of extract vector elts into a
+// vector zero extend and possibly an extract subvector.
+// TODO: Support sign extend or any extend?
+// TODO: Allow undef elements?
+// TODO: Don't require the extracts to start at element 0.
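+// For example:
+//   (v4i32 build_vector (zext (extractelt V, 0)), ...,
+//                       (zext (extractelt V, 3)))
+//     --> (v4i32 zero_extend (v4i16 extract_subvector V, 0)), with V: v8i16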
+SDValue DAGCombiner::convertBuildVecZextToZext(SDNode *N) {
+ if (LegalOperations)
+ return SDValue();
+
+ EVT VT = N->getValueType(0);
+
+ SDValue Op0 = N->getOperand(0);
+ auto checkElem = [&](SDValue Op) -> int64_t {
+ if (Op.getOpcode() == ISD::ZERO_EXTEND &&
+ Op.getOperand(0).getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+ Op0.getOperand(0).getOperand(0) == Op.getOperand(0).getOperand(0))
+ if (auto *C = dyn_cast<ConstantSDNode>(Op.getOperand(0).getOperand(1)))
+ return C->getZExtValue();
+ return -1;
+ };
+
+ // Make sure the first element matches
+ // (zext (extract_vector_elt X, C))
+ int64_t Offset = checkElem(Op0);
+ if (Offset < 0)
+ return SDValue();
+
+ unsigned NumElems = N->getNumOperands();
+ SDValue In = Op0.getOperand(0).getOperand(0);
+ EVT InSVT = In.getValueType().getScalarType();
+ EVT InVT = EVT::getVectorVT(*DAG.getContext(), InSVT, NumElems);
+
+ // Don't create an illegal input type after type legalization.
+ if (LegalTypes && !TLI.isTypeLegal(InVT))
+ return SDValue();
+
+ // Ensure all the elements come from the same vector and are adjacent.
+ for (unsigned i = 1; i != NumElems; ++i) {
+ if ((Offset + i) != checkElem(N->getOperand(i)))
return SDValue();
+ }
- // We can't generate a shuffle node with mismatched input and output types.
- // Attempt to transform a single input vector to the correct type.
- if ((VT != VecIn1.getValueType())) {
- // If the input vector type has a different base type to the output
- // vector type, bail out.
- EVT VTElemType = VT.getVectorElementType();
- if ((VecIn1.getValueType().getVectorElementType() != VTElemType) ||
- (VecIn2.getNode() &&
- (VecIn2.getValueType().getVectorElementType() != VTElemType)))
- return SDValue();
+ SDLoc DL(N);
+ In = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InVT, In,
+ Op0.getOperand(0).getOperand(1));
+ return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, In);
+}
- // If the input vector is too small, widen it.
- // We only support widening of vectors which are half the size of the
- // output registers. For example XMM->YMM widening on X86 with AVX.
- EVT VecInT = VecIn1.getValueType();
- if (VecInT.getSizeInBits() * 2 == VT.getSizeInBits()) {
- // If we only have one small input, widen it by adding undef values.
- if (!VecIn2.getNode())
- VecIn1 = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, VecIn1,
- DAG.getUNDEF(VecIn1.getValueType()));
- else if (VecIn1.getValueType() == VecIn2.getValueType()) {
- // If we have two small inputs of the same type, try to concat them.
- VecIn1 = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, VecIn1, VecIn2);
- VecIn2 = SDValue(nullptr, 0);
- } else
- return SDValue();
- } else if (VecInT.getSizeInBits() == VT.getSizeInBits() * 2) {
- // If the input vector is too large, try to split it.
- // We don't support having two input vectors that are too large.
- // If the zero vector was used, we can not split the vector,
- // since we'd need 3 inputs.
- if (UsesZeroVector || VecIn2.getNode())
- return SDValue();
+SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) {
+ EVT VT = N->getValueType(0);
- if (!TLI.isExtractSubvectorCheap(VT, VT.getVectorNumElements()))
- return SDValue();
+ // A vector built entirely of undefs is undef.
+ if (ISD::allOperandsUndef(N))
+ return DAG.getUNDEF(VT);
- // Try to replace VecIn1 with two extract_subvectors
- // No need to update the masks, they should still be correct.
- VecIn2 = DAG.getNode(
- ISD::EXTRACT_SUBVECTOR, dl, VT, VecIn1,
- DAG.getConstant(VT.getVectorNumElements(), dl,
- TLI.getVectorIdxTy(DAG.getDataLayout())));
- VecIn1 = DAG.getNode(
- ISD::EXTRACT_SUBVECTOR, dl, VT, VecIn1,
- DAG.getConstant(0, dl, TLI.getVectorIdxTy(DAG.getDataLayout())));
- } else
- return SDValue();
+ // If this is a splat of a bitcast from another vector, change to a
+ // concat_vectors.
+ // For example:
+ // (build_vector (i64 (bitcast (v2i32 X))), (i64 (bitcast (v2i32 X)))) ->
+ // (v2i64 (bitcast (concat_vectors (v2i32 X), (v2i32 X))))
+ //
+ // If X is a build_vector itself, the concat can become a larger build_vector.
+ // TODO: Maybe this is useful for non-splat too?
+ if (!LegalOperations) {
+ if (SDValue Splat = cast<BuildVectorSDNode>(N)->getSplatValue()) {
+ Splat = peekThroughBitcasts(Splat);
+ EVT SrcVT = Splat.getValueType();
+ if (SrcVT.isVector()) {
+ unsigned NumElts = N->getNumOperands() * SrcVT.getVectorNumElements();
+ EVT NewVT = EVT::getVectorVT(*DAG.getContext(),
+ SrcVT.getVectorElementType(), NumElts);
+ if (!LegalTypes || TLI.isTypeLegal(NewVT)) {
+ SmallVector<SDValue, 8> Ops(N->getNumOperands(), Splat);
+ SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N),
+ NewVT, Ops);
+ return DAG.getBitcast(VT, Concat);
+ }
+ }
}
+ }
- if (UsesZeroVector)
- VecIn2 = VT.isInteger() ? DAG.getConstant(0, dl, VT) :
- DAG.getConstantFP(0.0, dl, VT);
- else
- // If VecIn2 is unused then change it to undef.
- VecIn2 = VecIn2.getNode() ? VecIn2 : DAG.getUNDEF(VT);
+ // Check if we can express the BUILD_VECTOR via a subvector extract.
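+ // For example:
+ //   (v2i32 build_vector (extractelt X, 2), (extractelt X, 3))
+ //     --> (v2i32 extract_subvector X, 2), with X: v4i32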
+ if (!LegalTypes && (N->getNumOperands() > 1)) {
+ SDValue Op0 = N->getOperand(0);
+ auto checkElem = [&](SDValue Op) -> uint64_t {
+ if ((Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT) &&
+ (Op0.getOperand(0) == Op.getOperand(0)))
+ if (auto CNode = dyn_cast<ConstantSDNode>(Op.getOperand(1)))
+ return CNode->getZExtValue();
+ return -1;
+ };
- // Check that we were able to transform all incoming values to the same
- // type.
- if (VecIn2.getValueType() != VecIn1.getValueType() ||
- VecIn1.getValueType() != VT)
- return SDValue();
+ int Offset = checkElem(Op0);
+ for (unsigned i = 0; i < N->getNumOperands(); ++i) {
+ if (Offset + i != checkElem(N->getOperand(i))) {
+ Offset = -1;
+ break;
+ }
+ }
- // Return the new VECTOR_SHUFFLE node.
- SDValue Ops[2];
- Ops[0] = VecIn1;
- Ops[1] = VecIn2;
- return DAG.getVectorShuffle(VT, dl, Ops[0], Ops[1], &Mask[0]);
+ if ((Offset == 0) &&
+ (Op0.getOperand(0).getValueType() == N->getValueType(0)))
+ return Op0.getOperand(0);
+ if ((Offset != -1) &&
+ ((Offset % N->getValueType(0).getVectorNumElements()) ==
+ 0)) // IDX must be a multiple of the output size.
+ return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), N->getValueType(0),
+ Op0.getOperand(0), Op0.getOperand(1));
}
+ if (SDValue V = convertBuildVecZextToZext(N))
+ return V;
+
+ if (SDValue V = reduceBuildVecExtToExtBuildVec(N))
+ return V;
+
+ if (SDValue V = reduceBuildVecToShuffle(N))
+ return V;
+
return SDValue();
}
@@ -12751,18 +16653,17 @@ static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) {
for (SDValue &Op : Ops) {
if (Op.getValueType() == SVT)
continue;
- if (Op.getOpcode() == ISD::UNDEF)
+ if (Op.isUndef())
Op = ScalarUndef;
else
- Op = DAG.getNode(ISD::BITCAST, DL, SVT, Op);
+ Op = DAG.getBitcast(SVT, Op);
}
}
}
EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SVT,
VT.getSizeInBits() / SVT.getSizeInBits());
- return DAG.getNode(ISD::BITCAST, DL, VT,
- DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, Ops));
+ return DAG.getBitcast(VT, DAG.getBuildVector(VecVT, DL, Ops));
}
// Check to see if this is a CONCAT_VECTORS of a bunch of EXTRACT_SUBVECTOR
@@ -12779,12 +16680,10 @@ static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) {
SmallVector<int, 8> Mask;
for (SDValue Op : N->ops()) {
- // Peek through any bitcast.
- while (Op.getOpcode() == ISD::BITCAST)
- Op = Op.getOperand(0);
+ Op = peekThroughBitcasts(Op);
// UNDEF nodes convert to UNDEF shuffle mask values.
- if (Op.getOpcode() == ISD::UNDEF) {
+ if (Op.isUndef()) {
Mask.append((unsigned)NumOpElts, -1);
continue;
}
@@ -12798,20 +16697,17 @@ static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) {
// We want the EVT of the original extraction to correctly scale the
// extraction index.
EVT ExtVT = ExtVec.getValueType();
-
- // Peek through any bitcast.
- while (ExtVec.getOpcode() == ISD::BITCAST)
- ExtVec = ExtVec.getOperand(0);
+ ExtVec = peekThroughBitcasts(ExtVec);
// UNDEF nodes convert to UNDEF shuffle mask values.
- if (ExtVec.getOpcode() == ISD::UNDEF) {
+ if (ExtVec.isUndef()) {
Mask.append((unsigned)NumOpElts, -1);
continue;
}
if (!isa<ConstantSDNode>(Op.getOperand(1)))
return SDValue();
- int ExtIdx = cast<ConstantSDNode>(Op.getOperand(1))->getZExtValue();
+ int ExtIdx = Op.getConstantOperandVal(1);
  // Ensure that we are extracting a subvector from a vector of the same
  // size as the result.
@@ -12828,11 +16724,11 @@ static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) {
return SDValue();
// At most we can reference 2 inputs in the final shuffle.
- if (SV0.getOpcode() == ISD::UNDEF || SV0 == ExtVec) {
+ if (SV0.isUndef() || SV0 == ExtVec) {
SV0 = ExtVec;
for (int i = 0; i != NumOpElts; ++i)
Mask.push_back(i + ExtIdx);
- } else if (SV1.getOpcode() == ISD::UNDEF || SV1 == ExtVec) {
+ } else if (SV1.isUndef() || SV1 == ExtVec) {
SV1 = ExtVec;
for (int i = 0; i != NumOpElts; ++i)
Mask.push_back(i + ExtIdx + NumElts);
@@ -12860,16 +16756,24 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
// Optimize concat_vectors where all but the first of the vectors are undef.
if (std::all_of(std::next(N->op_begin()), N->op_end(), [](const SDValue &Op) {
- return Op.getOpcode() == ISD::UNDEF;
+ return Op.isUndef();
})) {
SDValue In = N->getOperand(0);
assert(In.getValueType().isVector() && "Must concat vectors");
- // Transform: concat_vectors(scalar, undef) -> scalar_to_vector(sclr).
- if (In->getOpcode() == ISD::BITCAST &&
- !In->getOperand(0)->getValueType(0).isVector()) {
- SDValue Scalar = In->getOperand(0);
+ SDValue Scalar = peekThroughOneUseBitcasts(In);
+ // concat_vectors(scalar_to_vector(scalar), undef) ->
+ // scalar_to_vector(scalar)
+ if (!LegalOperations && Scalar.getOpcode() == ISD::SCALAR_TO_VECTOR &&
+ Scalar.hasOneUse()) {
+ EVT SVT = Scalar.getValueType().getVectorElementType();
+ if (SVT == Scalar.getOperand(0).getValueType())
+ Scalar = Scalar.getOperand(0);
+ }
+
+ // concat_vectors(scalar, undef) -> scalar_to_vector(scalar)
+ if (!Scalar.getValueType().isVector()) {
// If the bitcast type isn't legal, it might be a trunc of a legal type;
// look through the trunc so we can still do the transform:
// concat_vectors(trunc(scalar), undef) -> scalar_to_vector(scalar)
@@ -12878,19 +16782,25 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
TLI.isTypeLegal(Scalar->getOperand(0).getValueType()))
Scalar = Scalar->getOperand(0);
- EVT SclTy = Scalar->getValueType(0);
+ EVT SclTy = Scalar.getValueType();
if (!SclTy.isFloatingPoint() && !SclTy.isInteger())
return SDValue();
- EVT NVT = EVT::getVectorVT(*DAG.getContext(), SclTy,
- VT.getSizeInBits() / SclTy.getSizeInBits());
+ // Bail out if the vector size is not a multiple of the scalar size.
+ if (VT.getSizeInBits() % SclTy.getSizeInBits())
+ return SDValue();
+
+ unsigned VNTNumElms = VT.getSizeInBits() / SclTy.getSizeInBits();
+ if (VNTNumElms < 2)
+ return SDValue();
+
+ EVT NVT = EVT::getVectorVT(*DAG.getContext(), SclTy, VNTNumElms);
if (!TLI.isTypeLegal(NVT) || !TLI.isTypeLegal(Scalar.getValueType()))
return SDValue();
- SDLoc dl = SDLoc(N);
- SDValue Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, NVT, Scalar);
- return DAG.getNode(ISD::BITCAST, dl, VT, Res);
+ SDValue Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), NVT, Scalar);
+ return DAG.getBitcast(VT, Res);
}
}
@@ -12901,9 +16811,7 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
auto IsBuildVectorOrUndef = [](const SDValue &Op) {
return ISD::UNDEF == Op.getOpcode() || ISD::BUILD_VECTOR == Op.getOpcode();
};
- bool AllBuildVectorsOrUndefs =
- std::all_of(N->op_begin(), N->op_end(), IsBuildVectorOrUndef);
- if (AllBuildVectorsOrUndefs) {
+ if (llvm::all_of(N->ops(), IsBuildVectorOrUndef)) {
SmallVector<SDValue, 8> Opnds;
EVT SVT = VT.getScalarType();
@@ -12914,7 +16822,7 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
bool FoundMinVT = false;
for (const SDValue &Op : N->ops())
if (ISD::BUILD_VECTOR == Op.getOpcode()) {
- EVT OpSVT = Op.getOperand(0)->getValueType(0);
+ EVT OpSVT = Op.getOperand(0).getValueType();
MinVT = (!FoundMinVT || OpSVT.bitsLE(MinVT)) ? OpSVT : MinVT;
FoundMinVT = true;
}
@@ -12942,7 +16850,7 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
assert(VT.getVectorNumElements() == Opnds.size() &&
"Concat vector type mismatch");
- return DAG.getNode(ISD::BUILD_VECTOR, SDLoc(N), VT, Opnds);
+ return DAG.getBuildVector(VT, SDLoc(N), Opnds);
}
// Fold CONCAT_VECTORS of only bitcast scalars (or undef) to BUILD_VECTOR.
@@ -12964,7 +16872,7 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
SDValue Op = N->getOperand(i);
- if (Op.getOpcode() == ISD::UNDEF)
+ if (Op.isUndef())
continue;
// Check if this is the identity extract:
@@ -13002,19 +16910,177 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
return SDValue();
}
+/// If we are extracting a subvector produced by a wide binary operator, try
+/// to use a narrow binary operator and/or avoid concatenation and extraction.
+static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG) {
+ // TODO: Refactor with the caller (visitEXTRACT_SUBVECTOR), so we can share
+ // some of these bailouts with other transforms.
+
+ // The extract index must be a constant, so we can map it to a concat operand.
+ auto *ExtractIndexC = dyn_cast<ConstantSDNode>(Extract->getOperand(1));
+ if (!ExtractIndexC)
+ return SDValue();
+
+ // We are looking for an optionally bitcasted wide vector binary operator
+ // feeding an extract subvector.
+ SDValue BinOp = peekThroughBitcasts(Extract->getOperand(0));
+ if (!ISD::isBinaryOp(BinOp.getNode()))
+ return SDValue();
+
+ // The binop must be a vector type, so we can extract some fraction of it.
+ EVT WideBVT = BinOp.getValueType();
+ if (!WideBVT.isVector())
+ return SDValue();
+
+ EVT VT = Extract->getValueType(0);
+ unsigned ExtractIndex = ExtractIndexC->getZExtValue();
+ assert(ExtractIndex % VT.getVectorNumElements() == 0 &&
+ "Extract index is not a multiple of the vector length.");
+
+ // Bail out if this is not a proper multiple width extraction.
+ unsigned WideWidth = WideBVT.getSizeInBits();
+ unsigned NarrowWidth = VT.getSizeInBits();
+ if (WideWidth % NarrowWidth != 0)
+ return SDValue();
+
+ // Bail out if we are extracting a fraction of a single operation. This can
+ // occur because we potentially looked through a bitcast of the binop.
+ unsigned NarrowingRatio = WideWidth / NarrowWidth;
+ unsigned WideNumElts = WideBVT.getVectorNumElements();
+ if (WideNumElts % NarrowingRatio != 0)
+ return SDValue();
+
+ // Bail out if the target does not support a narrower version of the binop.
+ EVT NarrowBVT = EVT::getVectorVT(*DAG.getContext(), WideBVT.getScalarType(),
+ WideNumElts / NarrowingRatio);
+ unsigned BOpcode = BinOp.getOpcode();
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ if (!TLI.isOperationLegalOrCustomOrPromote(BOpcode, NarrowBVT))
+ return SDValue();
+
+ // If extraction is cheap, we don't need to look at the binop operands
+ // for concat ops. The narrow binop alone makes this transform profitable.
+ // We can't just reuse the original extract index operand because we may have
+ // bitcasted.
+ unsigned ConcatOpNum = ExtractIndex / VT.getVectorNumElements();
+ unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements();
+ EVT ExtBOIdxVT = Extract->getOperand(1).getValueType();
+ if (TLI.isExtractSubvectorCheap(NarrowBVT, WideBVT, ExtBOIdx) &&
+ BinOp.hasOneUse() && Extract->getOperand(0)->hasOneUse()) {
+ // extract (binop B0, B1), N --> binop (extract B0, N), (extract B1, N)
+ SDLoc DL(Extract);
+ SDValue NewExtIndex = DAG.getConstant(ExtBOIdx, DL, ExtBOIdxVT);
+ SDValue X = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
+ BinOp.getOperand(0), NewExtIndex);
+ SDValue Y = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
+ BinOp.getOperand(1), NewExtIndex);
+ SDValue NarrowBinOp = DAG.getNode(BOpcode, DL, NarrowBVT, X, Y,
+ BinOp.getNode()->getFlags());
+ return DAG.getBitcast(VT, NarrowBinOp);
+ }
+
+ // Only handle the case where we are doubling and then halving. A larger ratio
+ // may require more than two narrow binops to replace the wide binop.
+ if (NarrowingRatio != 2)
+ return SDValue();
+
+ // TODO: The motivating case for this transform is an x86 AVX1 target. That
+ // target has temptingly almost legal versions of bitwise logic ops in 256-bit
+ // flavors, but no other 256-bit integer support. This could be extended to
+ // handle any binop, but that may require fixing/adding other folds to avoid
+ // codegen regressions.
+ if (BOpcode != ISD::AND && BOpcode != ISD::OR && BOpcode != ISD::XOR)
+ return SDValue();
+
+ // We need at least one concatenation operation of a binop operand to make
+ // this transform worthwhile. The concat must double the input vector sizes.
+ // TODO: Should we also handle INSERT_SUBVECTOR patterns?
+ SDValue LHS = peekThroughBitcasts(BinOp.getOperand(0));
+ SDValue RHS = peekThroughBitcasts(BinOp.getOperand(1));
+ bool ConcatL =
+ LHS.getOpcode() == ISD::CONCAT_VECTORS && LHS.getNumOperands() == 2;
+ bool ConcatR =
+ RHS.getOpcode() == ISD::CONCAT_VECTORS && RHS.getNumOperands() == 2;
+ if (!ConcatL && !ConcatR)
+ return SDValue();
+
+ // If one of the binop operands was not the result of a concat, we must
+ // extract a half-sized operand for our new narrow binop.
+ SDLoc DL(Extract);
+
+ // extract (binop (concat X1, X2), (concat Y1, Y2)), N --> binop XN, YN
+ // extract (binop (concat X1, X2), Y), N --> binop XN, (extract Y, N)
+ // extract (binop X, (concat Y1, Y2)), N --> binop (extract X, N), YN
+ SDValue X = ConcatL ? DAG.getBitcast(NarrowBVT, LHS.getOperand(ConcatOpNum))
+ : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
+ BinOp.getOperand(0),
+ DAG.getConstant(ExtBOIdx, DL, ExtBOIdxVT));
+
+ SDValue Y = ConcatR ? DAG.getBitcast(NarrowBVT, RHS.getOperand(ConcatOpNum))
+ : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
+ BinOp.getOperand(1),
+ DAG.getConstant(ExtBOIdx, DL, ExtBOIdxVT));
+
+ SDValue NarrowBinOp = DAG.getNode(BOpcode, DL, NarrowBVT, X, Y);
+ return DAG.getBitcast(VT, NarrowBinOp);
+}
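
As a reference for the index arithmetic in narrowExtractedVectorBinOp, here is a minimal standalone sketch in plain C++ (no LLVM types; all names are illustrative). It reproduces the narrowing-ratio checks and the mapping from the extract index to a concat operand and a new extract index:

#include <cassert>
#include <cstdio>

// Illustrative-only plan for narrowing an extract of a wide binop.
struct NarrowPlan { unsigned Ratio, ConcatOpNum, NewExtIdx; bool OK; };

NarrowPlan planNarrowing(unsigned WideBits, unsigned NarrowBits,
                         unsigned WideNumElts, unsigned NarrowNumElts,
                         unsigned ExtractIndex) {
  NarrowPlan P{0, 0, 0, false};
  // Bail out if this is not a proper multiple width extraction.
  if (NarrowBits == 0 || WideBits % NarrowBits != 0)
    return P;
  P.Ratio = WideBits / NarrowBits;
  // Bail out if we would extract a fraction of a single wide element.
  if (WideNumElts % P.Ratio != 0)
    return P;
  // Which concat operand the extract reads, and the equivalent extract
  // index on the binop's inputs (may differ after bitcasts).
  P.ConcatOpNum = ExtractIndex / NarrowNumElts;
  P.NewExtIdx = P.ConcatOpNum * (WideNumElts / P.Ratio);
  P.OK = true;
  return P;
}

int main() {
  // Extracting the high v4i32 half (index 4) of a 256-bit v8i32 binop.
  NarrowPlan P = planNarrowing(256, 128, 8, 4, 4);
  assert(P.OK && P.Ratio == 2 && P.ConcatOpNum == 1 && P.NewExtIdx == 4);
  std::printf("ratio=%u concat-op=%u new-index=%u\n",
              P.Ratio, P.ConcatOpNum, P.NewExtIdx);
  return 0;
}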
+
+/// If we are extracting a subvector from a wide vector load, convert to a
+/// narrow load to eliminate the extraction:
+/// (extract_subvector (load wide vector)) --> (load narrow vector)
+static SDValue narrowExtractedVectorLoad(SDNode *Extract, SelectionDAG &DAG) {
+ // TODO: Add support for big-endian. The offset calculation must be adjusted.
+ if (DAG.getDataLayout().isBigEndian())
+ return SDValue();
+
+ auto *Ld = dyn_cast<LoadSDNode>(Extract->getOperand(0));
+ auto *ExtIdx = dyn_cast<ConstantSDNode>(Extract->getOperand(1));
+ if (!Ld || Ld->getExtensionType() || Ld->isVolatile() || !ExtIdx)
+ return SDValue();
+
+ // Allow targets to opt-out.
+ EVT VT = Extract->getValueType(0);
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ if (!TLI.shouldReduceLoadWidth(Ld, Ld->getExtensionType(), VT))
+ return SDValue();
+
+ // The narrow load will be offset from the base address of the old load if
+ // we are extracting from something besides index 0 (little-endian).
+ SDLoc DL(Extract);
+ SDValue BaseAddr = Ld->getOperand(1);
+ unsigned Offset = ExtIdx->getZExtValue() * VT.getScalarType().getStoreSize();
+
+ // TODO: Use "BaseIndexOffset" to make this more effective.
+ SDValue NewAddr = DAG.getMemBasePlusOffset(BaseAddr, Offset, DL);
+ MachineFunction &MF = DAG.getMachineFunction();
+ MachineMemOperand *MMO = MF.getMachineMemOperand(Ld->getMemOperand(), Offset,
+ VT.getStoreSize());
+ SDValue NewLd = DAG.getLoad(VT, DL, Ld->getChain(), NewAddr, MMO);
+ DAG.makeEquivalentMemoryOrdering(Ld, NewLd);
+ return NewLd;
+}
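
A little-endian sketch of the offset computation above, assuming nothing beyond standard C++: loading the narrow vector at base + ExtIdx * element-store-size yields the same bytes as loading the wide vector and extracting.

#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  // The wide v8i32 as it sits in (little-endian) memory.
  uint32_t Wide[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  unsigned ExtIdx = 4;                       // extract_subvector ..., 4
  // Byte offset of the narrow load from the old base address.
  size_t Offset = ExtIdx * sizeof(uint32_t); // ExtIdx * scalar store size
  uint32_t Narrow[4];
  std::memcpy(Narrow, reinterpret_cast<const char *>(Wide) + Offset,
              sizeof(Narrow));
  assert(Narrow[0] == 4 && Narrow[3] == 7);  // the upper half, as expected
  return 0;
}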
+
SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode* N) {
EVT NVT = N->getValueType(0);
SDValue V = N->getOperand(0);
- if (V->getOpcode() == ISD::CONCAT_VECTORS) {
- // Combine:
- // (extract_subvec (concat V1, V2, ...), i)
- // Into:
- // Vi if possible
- // Only operand 0 is checked as 'concat' assumes all inputs of the same
- // type.
- if (V->getOperand(0).getValueType() != NVT)
- return SDValue();
+ // Extract from UNDEF is UNDEF.
+ if (V.isUndef())
+ return DAG.getUNDEF(NVT);
+
+ if (TLI.isOperationLegalOrCustomOrPromote(ISD::LOAD, NVT))
+ if (SDValue NarrowLoad = narrowExtractedVectorLoad(N, DAG))
+ return NarrowLoad;
+
+ // Combine:
+ // (extract_subvec (concat V1, V2, ...), i)
+ // Into:
+ // Vi if possible
+ // Only operand 0 is checked as 'concat' assumes all inputs of the same
+ // type.
+ if (V.getOpcode() == ISD::CONCAT_VECTORS &&
+ isa<ConstantSDNode>(N->getOperand(1)) &&
+ V.getOperand(0).getValueType() == NVT) {
unsigned Idx = N->getConstantOperandVal(1);
unsigned NumElems = NVT.getVectorNumElements();
assert((Idx % NumElems) == 0 &&
@@ -13022,128 +17088,78 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode* N) {
return V->getOperand(Idx / NumElems);
}
- // Skip bitcasting
- if (V->getOpcode() == ISD::BITCAST)
- V = V.getOperand(0);
+ V = peekThroughBitcasts(V);
+
+ // If the input is a build vector, try to make a smaller build vector.
+ if (V.getOpcode() == ISD::BUILD_VECTOR) {
+ if (auto *Idx = dyn_cast<ConstantSDNode>(N->getOperand(1))) {
+ EVT InVT = V.getValueType();
+ unsigned ExtractSize = NVT.getSizeInBits();
+ unsigned EltSize = InVT.getScalarSizeInBits();
+ // Only do this if we won't split any elements.
+ if (ExtractSize % EltSize == 0) {
+ unsigned NumElems = ExtractSize / EltSize;
+ EVT EltVT = InVT.getVectorElementType();
+ EVT ExtractVT = NumElems == 1 ? EltVT :
+ EVT::getVectorVT(*DAG.getContext(), EltVT, NumElems);
+ if ((Level < AfterLegalizeDAG ||
+ (NumElems == 1 ||
+ TLI.isOperationLegal(ISD::BUILD_VECTOR, ExtractVT))) &&
+ (!LegalTypes || TLI.isTypeLegal(ExtractVT))) {
+ unsigned IdxVal = (Idx->getZExtValue() * NVT.getScalarSizeInBits()) /
+ EltSize;
+ if (NumElems == 1) {
+ SDValue Src = V->getOperand(IdxVal);
+ if (EltVT != Src.getValueType())
+ Src = DAG.getNode(ISD::TRUNCATE, SDLoc(N), InVT, Src);
+
+ return DAG.getBitcast(NVT, Src);
+ }
+
+ // Extract the pieces from the original build_vector.
+ SDValue BuildVec = DAG.getBuildVector(ExtractVT, SDLoc(N),
+ makeArrayRef(V->op_begin() + IdxVal,
+ NumElems));
+ return DAG.getBitcast(NVT, BuildVec);
+ }
+ }
+ }
+ }
- if (V->getOpcode() == ISD::INSERT_SUBVECTOR) {
- SDLoc dl(N);
+ if (V.getOpcode() == ISD::INSERT_SUBVECTOR) {
// Handle only simple case where vector being inserted and vector
- // being extracted are of same type, and are half size of larger vectors.
- EVT BigVT = V->getOperand(0).getValueType();
- EVT SmallVT = V->getOperand(1).getValueType();
- if (!NVT.bitsEq(SmallVT) || NVT.getSizeInBits()*2 != BigVT.getSizeInBits())
+ // being extracted are of same size.
+ EVT SmallVT = V.getOperand(1).getValueType();
+ if (!NVT.bitsEq(SmallVT))
return SDValue();
- // Only handle cases where both indexes are constants with the same type.
- ConstantSDNode *ExtIdx = dyn_cast<ConstantSDNode>(N->getOperand(1));
- ConstantSDNode *InsIdx = dyn_cast<ConstantSDNode>(V->getOperand(2));
+ // Only handle cases where both indexes are constants.
+ auto *ExtIdx = dyn_cast<ConstantSDNode>(N->getOperand(1));
+ auto *InsIdx = dyn_cast<ConstantSDNode>(V.getOperand(2));
- if (InsIdx && ExtIdx &&
- InsIdx->getValueType(0).getSizeInBits() <= 64 &&
- ExtIdx->getValueType(0).getSizeInBits() <= 64) {
+ if (InsIdx && ExtIdx) {
// Combine:
// (extract_subvec (insert_subvec V1, V2, InsIdx), ExtIdx)
// Into:
// indices are equal or bit offsets are equal => V1
// otherwise => (extract_subvec V1, ExtIdx)
- if (InsIdx->getZExtValue() * SmallVT.getScalarType().getSizeInBits() ==
- ExtIdx->getZExtValue() * NVT.getScalarType().getSizeInBits())
- return DAG.getNode(ISD::BITCAST, dl, NVT, V->getOperand(1));
- return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, NVT,
- DAG.getNode(ISD::BITCAST, dl,
- N->getOperand(0).getValueType(),
- V->getOperand(0)), N->getOperand(1));
+ if (InsIdx->getZExtValue() * SmallVT.getScalarSizeInBits() ==
+ ExtIdx->getZExtValue() * NVT.getScalarSizeInBits())
+ return DAG.getBitcast(NVT, V.getOperand(1));
+ return DAG.getNode(
+ ISD::EXTRACT_SUBVECTOR, SDLoc(N), NVT,
+ DAG.getBitcast(N->getOperand(0).getValueType(), V.getOperand(0)),
+ N->getOperand(1));
}
}
- return SDValue();
-}
-
-static SDValue simplifyShuffleOperandRecursively(SmallBitVector &UsedElements,
- SDValue V, SelectionDAG &DAG) {
- SDLoc DL(V);
- EVT VT = V.getValueType();
-
- switch (V.getOpcode()) {
- default:
- return V;
-
- case ISD::CONCAT_VECTORS: {
- EVT OpVT = V->getOperand(0).getValueType();
- int OpSize = OpVT.getVectorNumElements();
- SmallBitVector OpUsedElements(OpSize, false);
- bool FoundSimplification = false;
- SmallVector<SDValue, 4> NewOps;
- NewOps.reserve(V->getNumOperands());
- for (int i = 0, NumOps = V->getNumOperands(); i < NumOps; ++i) {
- SDValue Op = V->getOperand(i);
- bool OpUsed = false;
- for (int j = 0; j < OpSize; ++j)
- if (UsedElements[i * OpSize + j]) {
- OpUsedElements[j] = true;
- OpUsed = true;
- }
- NewOps.push_back(
- OpUsed ? simplifyShuffleOperandRecursively(OpUsedElements, Op, DAG)
- : DAG.getUNDEF(OpVT));
- FoundSimplification |= Op == NewOps.back();
- OpUsedElements.reset();
- }
- if (FoundSimplification)
- V = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, NewOps);
- return V;
- }
-
- case ISD::INSERT_SUBVECTOR: {
- SDValue BaseV = V->getOperand(0);
- SDValue SubV = V->getOperand(1);
- auto *IdxN = dyn_cast<ConstantSDNode>(V->getOperand(2));
- if (!IdxN)
- return V;
-
- int SubSize = SubV.getValueType().getVectorNumElements();
- int Idx = IdxN->getZExtValue();
- bool SubVectorUsed = false;
- SmallBitVector SubUsedElements(SubSize, false);
- for (int i = 0; i < SubSize; ++i)
- if (UsedElements[i + Idx]) {
- SubVectorUsed = true;
- SubUsedElements[i] = true;
- UsedElements[i + Idx] = false;
- }
-
- // Now recurse on both the base and sub vectors.
- SDValue SimplifiedSubV =
- SubVectorUsed
- ? simplifyShuffleOperandRecursively(SubUsedElements, SubV, DAG)
- : DAG.getUNDEF(SubV.getValueType());
- SDValue SimplifiedBaseV = simplifyShuffleOperandRecursively(UsedElements, BaseV, DAG);
- if (SimplifiedSubV != SubV || SimplifiedBaseV != BaseV)
- V = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT,
- SimplifiedBaseV, SimplifiedSubV, V->getOperand(2));
- return V;
- }
- }
-}
+ if (SDValue NarrowBOp = narrowExtractedVectorBinOp(N, DAG))
+ return NarrowBOp;
-static SDValue simplifyShuffleOperands(ShuffleVectorSDNode *SVN, SDValue N0,
- SDValue N1, SelectionDAG &DAG) {
- EVT VT = SVN->getValueType(0);
- int NumElts = VT.getVectorNumElements();
- SmallBitVector N0UsedElements(NumElts, false), N1UsedElements(NumElts, false);
- for (int M : SVN->getMask())
- if (M >= 0 && M < NumElts)
- N0UsedElements[M] = true;
- else if (M >= NumElts)
- N1UsedElements[M - NumElts] = true;
-
- SDValue S0 = simplifyShuffleOperandRecursively(N0UsedElements, N0, DAG);
- SDValue S1 = simplifyShuffleOperandRecursively(N1UsedElements, N1, DAG);
- if (S0 == N0 && S1 == N1)
- return SDValue();
+ if (SimplifyDemandedVectorElts(SDValue(N, 0)))
+ return SDValue(N, 0);
- return DAG.getVectorShuffle(VT, SDLoc(SVN), S0, S1, SVN->getMask());
+ return SDValue();
}
// Tries to turn a shuffle of two CONCAT_VECTORS into a single concat,
@@ -13164,7 +17180,7 @@ static SDValue partitionShuffleOfConcats(SDNode *N, SelectionDAG &DAG) {
// Special case: shuffle(concat(A,B)) can be more efficiently represented
// as concat(shuffle(A,B),UNDEF) if the shuffle doesn't set any of the high
// half vector elements.
- if (NumElemsPerConcat * 2 == NumElts && N1.getOpcode() == ISD::UNDEF &&
+ if (NumElemsPerConcat * 2 == NumElts && N1.isUndef() &&
std::all_of(SVN->getMask().begin() + NumElemsPerConcat,
SVN->getMask().end(), [](int i) { return i == -1; })) {
N0 = DAG.getVectorShuffle(ConcatVT, SDLoc(N), N0.getOperand(0), N0.getOperand(1),
@@ -13210,6 +17226,343 @@ static SDValue partitionShuffleOfConcats(SDNode *N, SelectionDAG &DAG) {
return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
}
+// Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
+// BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
+//
+// SHUFFLE(BUILD_VECTOR(), BUILD_VECTOR()) -> BUILD_VECTOR() is always
+// a simplification in some sense, but it isn't appropriate in general: some
+// BUILD_VECTORs are substantially cheaper than others. The general case
+// of a BUILD_VECTOR requires inserting each element individually (or
+// performing the equivalent in a temporary stack variable). A BUILD_VECTOR of
+// all constants is a single constant pool load. A BUILD_VECTOR where each
+// element is identical is a splat. A BUILD_VECTOR where most of the operands
+// are undef lowers to a small number of element insertions.
+//
+// To deal with this, we currently use a bunch of mostly arbitrary heuristics.
+// We don't fold shuffles where one side is a non-zero constant, and we don't
+// fold shuffles if the resulting (non-splat) BUILD_VECTOR would have duplicate
+// non-constant operands. This seems to work out reasonably well in practice.
+static SDValue combineShuffleOfScalars(ShuffleVectorSDNode *SVN,
+ SelectionDAG &DAG,
+ const TargetLowering &TLI) {
+ EVT VT = SVN->getValueType(0);
+ unsigned NumElts = VT.getVectorNumElements();
+ SDValue N0 = SVN->getOperand(0);
+ SDValue N1 = SVN->getOperand(1);
+
+ if (!N0->hasOneUse())
+ return SDValue();
+
+ // If only one of N0,N1 is constant, bail out if it is not ALL_ZEROS as
+ // discussed above.
+ if (!N1.isUndef()) {
+ if (!N1->hasOneUse())
+ return SDValue();
+
+ bool N0AnyConst = isAnyConstantBuildVector(N0);
+ bool N1AnyConst = isAnyConstantBuildVector(N1);
+ if (N0AnyConst && !N1AnyConst && !ISD::isBuildVectorAllZeros(N0.getNode()))
+ return SDValue();
+ if (!N0AnyConst && N1AnyConst && !ISD::isBuildVectorAllZeros(N1.getNode()))
+ return SDValue();
+ }
+
+ // If both inputs are splats of the same value then we can safely merge this
+ // to a single BUILD_VECTOR with undef elements based on the shuffle mask.
+ bool IsSplat = false;
+ auto *BV0 = dyn_cast<BuildVectorSDNode>(N0);
+ auto *BV1 = dyn_cast<BuildVectorSDNode>(N1);
+ if (BV0 && BV1)
+ if (SDValue Splat0 = BV0->getSplatValue())
+ IsSplat = (Splat0 == BV1->getSplatValue());
+
+ SmallVector<SDValue, 8> Ops;
+ SmallSet<SDValue, 16> DuplicateOps;
+ for (int M : SVN->getMask()) {
+ SDValue Op = DAG.getUNDEF(VT.getScalarType());
+ if (M >= 0) {
+ int Idx = M < (int)NumElts ? M : M - NumElts;
+ SDValue &S = (M < (int)NumElts ? N0 : N1);
+ if (S.getOpcode() == ISD::BUILD_VECTOR) {
+ Op = S.getOperand(Idx);
+ } else if (S.getOpcode() == ISD::SCALAR_TO_VECTOR) {
+ assert(Idx == 0 && "Unexpected SCALAR_TO_VECTOR operand index.");
+ Op = S.getOperand(0);
+ } else {
+ // Operand can't be combined - bail out.
+ return SDValue();
+ }
+ }
+
+ // Don't duplicate a non-constant BUILD_VECTOR operand unless we're
+ // generating a splat; semantically, this is fine, but it's likely to
+ // generate low-quality code if the target can't reconstruct an appropriate
+ // shuffle.
+ if (!Op.isUndef() && !isa<ConstantSDNode>(Op) && !isa<ConstantFPSDNode>(Op))
+ if (!IsSplat && !DuplicateOps.insert(Op).second)
+ return SDValue();
+
+ Ops.push_back(Op);
+ }
+
+ // BUILD_VECTOR requires all inputs to be of the same type; find the
+ // maximum type and extend them all.
+ EVT SVT = VT.getScalarType();
+ if (SVT.isInteger())
+ for (SDValue &Op : Ops)
+ SVT = (SVT.bitsLT(Op.getValueType()) ? Op.getValueType() : SVT);
+ if (SVT != VT.getScalarType())
+ for (SDValue &Op : Ops)
+ Op = TLI.isZExtFree(Op.getValueType(), SVT)
+ ? DAG.getZExtOrTrunc(Op, SDLoc(SVN), SVT)
+ : DAG.getSExtOrTrunc(Op, SDLoc(SVN), SVT);
+ return DAG.getBuildVector(VT, SDLoc(SVN), Ops);
+}
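
The heart of combineShuffleOfScalars is a re-indexing of the two operand lists through the shuffle mask. A small sketch (plain C++17, illustrative names; -1 plays the role of an undef lane):

#include <cassert>
#include <optional>
#include <vector>

using Lane = std::optional<int>;  // nullopt models an undef element

std::vector<Lane> foldShuffleOfScalars(const std::vector<int> &Ops0,
                                       const std::vector<int> &Ops1,
                                       const std::vector<int> &Mask) {
  int NumElts = (int)Ops0.size();
  std::vector<Lane> Out;
  for (int M : Mask) {
    if (M < 0) {
      Out.push_back(std::nullopt);           // undef mask element
      continue;
    }
    Out.push_back(M < NumElts ? Ops0[M] : Ops1[M - NumElts]);
  }
  return Out;
}

int main() {
  // shuffle (build_vector 10,11,12,13), (build_vector 20,21,22,23), <0,5,2,7>
  auto R = foldShuffleOfScalars({10, 11, 12, 13}, {20, 21, 22, 23},
                                {0, 5, 2, 7});
  assert(R[0] == 10 && R[1] == 21 && R[2] == 12 && R[3] == 23);
  return 0;
}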
+
+// Match shuffles that can be converted to any_vector_extend_in_reg.
+// This is often generated during legalization.
+// e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src))
+// TODO Add support for ZERO_EXTEND_VECTOR_INREG when we have a test case.
+static SDValue combineShuffleToVectorExtend(ShuffleVectorSDNode *SVN,
+ SelectionDAG &DAG,
+ const TargetLowering &TLI,
+ bool LegalOperations) {
+ EVT VT = SVN->getValueType(0);
+ bool IsBigEndian = DAG.getDataLayout().isBigEndian();
+
+ // TODO Add support for big-endian when we have a test case.
+ if (!VT.isInteger() || IsBigEndian)
+ return SDValue();
+
+ unsigned NumElts = VT.getVectorNumElements();
+ unsigned EltSizeInBits = VT.getScalarSizeInBits();
+ ArrayRef<int> Mask = SVN->getMask();
+ SDValue N0 = SVN->getOperand(0);
+
+ // shuffle<0,-1,1,-1> == (v2i64 anyextend_vector_inreg(v4i32))
+ auto isAnyExtend = [&Mask, &NumElts](unsigned Scale) {
+ for (unsigned i = 0; i != NumElts; ++i) {
+ if (Mask[i] < 0)
+ continue;
+ if ((i % Scale) == 0 && Mask[i] == (int)(i / Scale))
+ continue;
+ return false;
+ }
+ return true;
+ };
+
+ // Attempt to match a '*_extend_vector_inreg' shuffle, we just search for
+ // power-of-2 extensions as they are the most likely.
+ for (unsigned Scale = 2; Scale < NumElts; Scale *= 2) {
+ // Check for non power of 2 vector sizes
+ if (NumElts % Scale != 0)
+ continue;
+ if (!isAnyExtend(Scale))
+ continue;
+
+ EVT OutSVT = EVT::getIntegerVT(*DAG.getContext(), EltSizeInBits * Scale);
+ EVT OutVT = EVT::getVectorVT(*DAG.getContext(), OutSVT, NumElts / Scale);
+ // Never create an illegal type. Only create unsupported operations if we
+ // are pre-legalization.
+ if (TLI.isTypeLegal(OutVT))
+ if (!LegalOperations ||
+ TLI.isOperationLegalOrCustom(ISD::ANY_EXTEND_VECTOR_INREG, OutVT))
+ return DAG.getBitcast(VT,
+ DAG.getNode(ISD::ANY_EXTEND_VECTOR_INREG,
+ SDLoc(SVN), OutVT, N0));
+ }
+
+ return SDValue();
+}
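
The isAnyExtend lambda above can be read in isolation: lane i must be undef, or must select element i/Scale when i is a multiple of Scale. A standalone restatement with a couple of checks (plain C++, illustrative names):

#include <cassert>
#include <vector>

bool isAnyExtendMask(const std::vector<int> &Mask, unsigned Scale) {
  for (unsigned i = 0; i != Mask.size(); ++i) {
    if (Mask[i] < 0)
      continue;                                   // undef lane is always fine
    if ((i % Scale) == 0 && Mask[i] == (int)(i / Scale))
      continue;                                   // low part of element i/Scale
    return false;
  }
  return true;
}

int main() {
  // v4i32 <0,u,1,u> matches a v2i64 any-extend-in-reg of the v4i32 source.
  assert(isAnyExtendMask({0, -1, 1, -1}, 2));
  // <0,1,1,u> does not: lane 1 is neither undef nor a scaled low part.
  assert(!isAnyExtendMask({0, 1, 1, -1}, 2));
  return 0;
}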
+
+// Detect 'truncate_vector_inreg' style shuffles that pack the lower parts of
+// each source element of a large type into the lowest elements of a smaller
+// destination type. This is often generated during legalization.
+// If the source node itself was a '*_extend_vector_inreg' node then we should
+// then be able to remove it.
+static SDValue combineTruncationShuffle(ShuffleVectorSDNode *SVN,
+ SelectionDAG &DAG) {
+ EVT VT = SVN->getValueType(0);
+ bool IsBigEndian = DAG.getDataLayout().isBigEndian();
+
+ // TODO Add support for big-endian when we have a test case.
+ if (!VT.isInteger() || IsBigEndian)
+ return SDValue();
+
+ SDValue N0 = peekThroughBitcasts(SVN->getOperand(0));
+
+ unsigned Opcode = N0.getOpcode();
+ if (Opcode != ISD::ANY_EXTEND_VECTOR_INREG &&
+ Opcode != ISD::SIGN_EXTEND_VECTOR_INREG &&
+ Opcode != ISD::ZERO_EXTEND_VECTOR_INREG)
+ return SDValue();
+
+ SDValue N00 = N0.getOperand(0);
+ ArrayRef<int> Mask = SVN->getMask();
+ unsigned NumElts = VT.getVectorNumElements();
+ unsigned EltSizeInBits = VT.getScalarSizeInBits();
+ unsigned ExtSrcSizeInBits = N00.getScalarValueSizeInBits();
+ unsigned ExtDstSizeInBits = N0.getScalarValueSizeInBits();
+
+ if (ExtDstSizeInBits % ExtSrcSizeInBits != 0)
+ return SDValue();
+ unsigned ExtScale = ExtDstSizeInBits / ExtSrcSizeInBits;
+
+ // (v4i32 truncate_vector_inreg(v2i64)) == shuffle<0,2,-1,-1>
+ // (v8i16 truncate_vector_inreg(v4i32)) == shuffle<0,2,4,6,-1,-1,-1,-1>
+ // (v8i16 truncate_vector_inreg(v2i64)) == shuffle<0,4,-1,-1,-1,-1,-1,-1>
+ auto isTruncate = [&Mask, &NumElts](unsigned Scale) {
+ for (unsigned i = 0; i != NumElts; ++i) {
+ if (Mask[i] < 0)
+ continue;
+ if ((i * Scale) < NumElts && Mask[i] == (int)(i * Scale))
+ continue;
+ return false;
+ }
+ return true;
+ };
+
+ // At the moment we just handle the case where we've truncated back to the
+ // same size as before the extension.
+ // TODO: handle more extension/truncation cases as cases arise.
+ if (EltSizeInBits != ExtSrcSizeInBits)
+ return SDValue();
+
+ // We can remove *extend_vector_inreg only if the truncation happens at
+ // the same scale as the extension.
+ if (isTruncate(ExtScale))
+ return DAG.getBitcast(VT, N00);
+
+ return SDValue();
+}
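
Likewise, the isTruncate lambda: lane i must be undef or select element i*Scale while that index is still in range. A standalone restatement using the mask examples from the comment above:

#include <cassert>
#include <vector>

bool isTruncateMask(const std::vector<int> &Mask, unsigned Scale) {
  unsigned NumElts = Mask.size();
  for (unsigned i = 0; i != NumElts; ++i) {
    if (Mask[i] < 0)
      continue;                                 // undef lane
    if ((i * Scale) < NumElts && Mask[i] == (int)(i * Scale))
      continue;                                 // packed low part
    return false;
  }
  return true;
}

int main() {
  // (v8i16 truncate_vector_inreg(v4i32)) == shuffle<0,2,4,6,-1,-1,-1,-1>
  assert(isTruncateMask({0, 2, 4, 6, -1, -1, -1, -1}, 2));
  // (v8i16 truncate_vector_inreg(v2i64)) == shuffle<0,4,-1,-1,-1,-1,-1,-1>
  assert(isTruncateMask({0, 4, -1, -1, -1, -1, -1, -1}, 4));
  return 0;
}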
+
+// Combine shuffles of splat-shuffles of the form:
+// shuffle (shuffle V, undef, splat-mask), undef, M
+// If splat-mask contains undef elements, we need to be careful about
+// introducing undefs in the folded mask that are not the result of composing
+// the masks of the shuffles.
+static SDValue combineShuffleOfSplat(ArrayRef<int> UserMask,
+ ShuffleVectorSDNode *Splat,
+ SelectionDAG &DAG) {
+ ArrayRef<int> SplatMask = Splat->getMask();
+ assert(UserMask.size() == SplatMask.size() && "Mask length mismatch");
+
+ // Prefer simplifying to the splat-shuffle, if possible. This is legal if
+ // every undef mask element in the splat-shuffle has a corresponding undef
+ // element in the user-shuffle's mask or if the composition of mask elements
+ // would result in undef.
+ // Examples for (shuffle (shuffle v, undef, SplatMask), undef, UserMask):
+ // * UserMask=[0,2,u,u], SplatMask=[2,u,2,u] -> [2,2,u,u]
+ // In this case it is not legal to simplify to the splat-shuffle because we
+ // may be exposing to the users of the shuffle an undef element at index 1
+ // which was not there before the combine.
+ // * UserMask=[0,u,2,u], SplatMask=[2,u,2,u] -> [2,u,2,u]
+ // In this case the composition of masks yields SplatMask, so it's ok to
+ // simplify to the splat-shuffle.
+ // * UserMask=[3,u,2,u], SplatMask=[2,u,2,u] -> [u,u,2,u]
+ // In this case the composed mask includes all undef elements of SplatMask
+ // and in addition sets element zero to undef. It is safe to simplify to
+ // the splat-shuffle.
+ auto CanSimplifyToExistingSplat = [](ArrayRef<int> UserMask,
+ ArrayRef<int> SplatMask) {
+ for (unsigned i = 0, e = UserMask.size(); i != e; ++i)
+ if (UserMask[i] != -1 && SplatMask[i] == -1 &&
+ SplatMask[UserMask[i]] != -1)
+ return false;
+ return true;
+ };
+ if (CanSimplifyToExistingSplat(UserMask, SplatMask))
+ return SDValue(Splat, 0);
+
+ // Create a new shuffle with a mask that is composed of the two shuffles'
+ // masks.
+ SmallVector<int, 32> NewMask;
+ for (int Idx : UserMask)
+ NewMask.push_back(Idx == -1 ? -1 : SplatMask[Idx]);
+
+ return DAG.getVectorShuffle(Splat->getValueType(0), SDLoc(Splat),
+ Splat->getOperand(0), Splat->getOperand(1),
+ NewMask);
+}
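
The final mask composition is the usual rule for chained shuffles: each user lane either stays undef or looks up the splat-shuffle's mask. A sketch using the first example from the comment above (plain C++; -1 is undef):

#include <cassert>
#include <vector>

std::vector<int> composeMasks(const std::vector<int> &UserMask,
                              const std::vector<int> &SplatMask) {
  std::vector<int> NewMask;
  for (int Idx : UserMask)
    NewMask.push_back(Idx == -1 ? -1 : SplatMask[Idx]);
  return NewMask;
}

int main() {
  // UserMask=[0,2,u,u], SplatMask=[2,u,2,u] -> [2,2,u,u]; this is the case
  // where simplifying to the splat-shuffle would expose a fresh undef, so
  // the composed mask is used instead.
  std::vector<int> R = composeMasks({0, 2, -1, -1}, {2, -1, 2, -1});
  assert((R == std::vector<int>{2, 2, -1, -1}));
  return 0;
}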
+
+/// If the shuffle mask is taking exactly one element from the first vector
+/// operand and passing through all other elements from the second vector
+/// operand, return the index of the mask element that is choosing an element
+/// from the first operand. Otherwise, return -1.
+static int getShuffleMaskIndexOfOneElementFromOp0IntoOp1(ArrayRef<int> Mask) {
+ int MaskSize = Mask.size();
+ int EltFromOp0 = -1;
+ // TODO: This does not match if there are undef elements in the shuffle mask.
+ // Should we ignore undefs in the shuffle mask instead? The trade-off is
+ // removing an instruction (a shuffle), but losing the knowledge that some
+ // vector lanes are not needed.
+ for (int i = 0; i != MaskSize; ++i) {
+ if (Mask[i] >= 0 && Mask[i] < MaskSize) {
+ // We're looking for a shuffle of exactly one element from operand 0.
+ if (EltFromOp0 != -1)
+ return -1;
+ EltFromOp0 = i;
+ } else if (Mask[i] != i + MaskSize) {
+ // Nothing from operand 1 can change lanes.
+ return -1;
+ }
+ }
+ return EltFromOp0;
+}
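
The same scan, written against a plain std::vector so it can be exercised in isolation (behavior matches the function above, including the rejection of undef mask elements noted in the TODO):

#include <cassert>
#include <vector>

int indexOfOneElementFromOp0(const std::vector<int> &Mask) {
  int MaskSize = (int)Mask.size();
  int EltFromOp0 = -1;
  for (int i = 0; i != MaskSize; ++i) {
    if (Mask[i] >= 0 && Mask[i] < MaskSize) {
      if (EltFromOp0 != -1)
        return -1;                 // a second element from operand 0
      EltFromOp0 = i;
    } else if (Mask[i] != i + MaskSize) {
      return -1;                   // operand-1 elements must not change lanes
    }
  }
  return EltFromOp0;
}

int main() {
  // Lane 2 takes element 3 of op0; all other lanes pass through op1.
  assert(indexOfOneElementFromOp0({4, 5, 3, 7}) == 2);
  // Two lanes read from op0 -> no single insertion point.
  assert(indexOfOneElementFromOp0({0, 5, 3, 7}) == -1);
  return 0;
}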
+
+/// If a shuffle inserts exactly one element from a source vector operand into
+/// another vector operand and we can access the specified element as a scalar,
+/// then we can eliminate the shuffle.
+static SDValue replaceShuffleOfInsert(ShuffleVectorSDNode *Shuf,
+ SelectionDAG &DAG) {
+ // First, check if we are taking one element of a vector and shuffling that
+ // element into another vector.
+ ArrayRef<int> Mask = Shuf->getMask();
+ SmallVector<int, 16> CommutedMask(Mask.begin(), Mask.end());
+ SDValue Op0 = Shuf->getOperand(0);
+ SDValue Op1 = Shuf->getOperand(1);
+ int ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(Mask);
+ if (ShufOp0Index == -1) {
+ // Commute mask and check again.
+ ShuffleVectorSDNode::commuteMask(CommutedMask);
+ ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(CommutedMask);
+ if (ShufOp0Index == -1)
+ return SDValue();
+ // Commute operands to match the commuted shuffle mask.
+ std::swap(Op0, Op1);
+ Mask = CommutedMask;
+ }
+
+ // The shuffle inserts exactly one element from operand 0 into operand 1.
+ // Now see if we can access that element as a scalar via a real insert element
+ // instruction.
+ // TODO: We can try harder to locate the element as a scalar. Examples: it
+ // could be an operand of SCALAR_TO_VECTOR, BUILD_VECTOR, or a constant.
+ assert(Mask[ShufOp0Index] >= 0 && Mask[ShufOp0Index] < (int)Mask.size() &&
+ "Shuffle mask value must be from operand 0");
+ if (Op0.getOpcode() != ISD::INSERT_VECTOR_ELT)
+ return SDValue();
+
+ auto *InsIndexC = dyn_cast<ConstantSDNode>(Op0.getOperand(2));
+ if (!InsIndexC || InsIndexC->getSExtValue() != Mask[ShufOp0Index])
+ return SDValue();
+
+ // There's an existing insertelement with constant insertion index, so we
+ // don't need to check the legality/profitability of a replacement operation
+ // that differs at most in the constant value. The target should be able to
+ // lower any of those in a similar way. If not, legalization will expand this
+ // to a scalar-to-vector plus shuffle.
+ //
+ // Note that the shuffle may move the scalar from the position that the insert
+ // element used. Therefore, our new insert element occurs at the shuffle's
+ // mask index value, not the insert's index value.
+ // shuffle (insertelt v1, x, C), v2, mask --> insertelt v2, x, C'
+ SDValue NewInsIndex = DAG.getConstant(ShufOp0Index, SDLoc(Shuf),
+ Op0.getOperand(2).getValueType());
+ return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(Shuf), Op0.getValueType(),
+ Op1, Op0.getOperand(1), NewInsIndex);
+}
+
SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
EVT VT = N->getValueType(0);
unsigned NumElts = VT.getVectorNumElements();
@@ -13220,7 +17573,7 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
assert(N0.getValueType() == VT && "Vector shuffle must be normalized in DAG");
// Canonicalize shuffle undef, undef -> undef
- if (N0.getOpcode() == ISD::UNDEF && N1.getOpcode() == ISD::UNDEF)
+ if (N0.isUndef() && N1.isUndef())
return DAG.getUNDEF(VT);
ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N);
@@ -13233,29 +17586,15 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
if (Idx >= (int)NumElts) Idx -= NumElts;
NewMask.push_back(Idx);
}
- return DAG.getVectorShuffle(VT, SDLoc(N), N0, DAG.getUNDEF(VT),
- &NewMask[0]);
+ return DAG.getVectorShuffle(VT, SDLoc(N), N0, DAG.getUNDEF(VT), NewMask);
}
// Canonicalize shuffle undef, v -> v, undef. Commute the shuffle mask.
- if (N0.getOpcode() == ISD::UNDEF) {
- SmallVector<int, 8> NewMask;
- for (unsigned i = 0; i != NumElts; ++i) {
- int Idx = SVN->getMaskElt(i);
- if (Idx >= 0) {
- if (Idx >= (int)NumElts)
- Idx -= NumElts;
- else
- Idx = -1; // remove reference to lhs
- }
- NewMask.push_back(Idx);
- }
- return DAG.getVectorShuffle(VT, SDLoc(N), N1, DAG.getUNDEF(VT),
- &NewMask[0]);
- }
+ if (N0.isUndef())
+ return DAG.getCommutedVectorShuffle(*SVN);
// Remove references to rhs if it is undef
- if (N1.getOpcode() == ISD::UNDEF) {
+ if (N1.isUndef()) {
bool Changed = false;
SmallVector<int, 8> NewMask;
for (unsigned i = 0; i != NumElts; ++i) {
@@ -13267,9 +17606,17 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
NewMask.push_back(Idx);
}
if (Changed)
- return DAG.getVectorShuffle(VT, SDLoc(N), N0, N1, &NewMask[0]);
+ return DAG.getVectorShuffle(VT, SDLoc(N), N0, N1, NewMask);
}
+ if (SDValue InsElt = replaceShuffleOfInsert(SVN, DAG))
+ return InsElt;
+
+ // A shuffle of a single vector that is a splat can always be folded.
+ if (auto *N0Shuf = dyn_cast<ShuffleVectorSDNode>(N0))
+ if (N1->isUndef() && N0Shuf->isSplat())
+ return combineShuffleOfSplat(SVN->getMask(), N0Shuf, DAG);
+
// If it is a splat, check if the argument vector is another splat or a
// build_vector.
if (SVN->isSplat() && SVN->getSplatIndex() < (int)NumElts) {
@@ -13291,7 +17638,7 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
SDValue Base;
bool AllSame = true;
for (unsigned i = 0; i != NumElts; ++i) {
- if (V->getOperand(i).getOpcode() != ISD::UNDEF) {
+ if (!V->getOperand(i).isUndef()) {
Base = V->getOperand(i);
break;
}
@@ -13312,86 +17659,49 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
// Canonicalize any other splat as a build_vector.
const SDValue &Splatted = V->getOperand(SVN->getSplatIndex());
SmallVector<SDValue, 8> Ops(NumElts, Splatted);
- SDValue NewBV = DAG.getNode(ISD::BUILD_VECTOR, SDLoc(N),
- V->getValueType(0), Ops);
+ SDValue NewBV = DAG.getBuildVector(V->getValueType(0), SDLoc(N), Ops);
// We may have jumped through bitcasts, so the type of the
// BUILD_VECTOR may not match the type of the shuffle.
if (V->getValueType(0) != VT)
- NewBV = DAG.getNode(ISD::BITCAST, SDLoc(N), VT, NewBV);
+ NewBV = DAG.getBitcast(VT, NewBV);
return NewBV;
}
}
- // There are various patterns used to build up a vector from smaller vectors,
- // subvectors, or elements. Scan chains of these and replace unused insertions
- // or components with undef.
- if (SDValue S = simplifyShuffleOperands(SVN, N0, N1, DAG))
- return S;
+ // Simplify source operands based on shuffle mask.
+ if (SimplifyDemandedVectorElts(SDValue(N, 0)))
+ return SDValue(N, 0);
+
+ // Match shuffles that can be converted to any_vector_extend_in_reg.
+ if (SDValue V = combineShuffleToVectorExtend(SVN, DAG, TLI, LegalOperations))
+ return V;
+
+ // Combine "truncate_vector_in_reg" style shuffles.
+ if (SDValue V = combineTruncationShuffle(SVN, DAG))
+ return V;
if (N0.getOpcode() == ISD::CONCAT_VECTORS &&
Level < AfterLegalizeVectorOps &&
- (N1.getOpcode() == ISD::UNDEF ||
+ (N1.isUndef() ||
(N1.getOpcode() == ISD::CONCAT_VECTORS &&
N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType()))) {
- SDValue V = partitionShuffleOfConcats(N, DAG);
-
- if (V.getNode())
+ if (SDValue V = partitionShuffleOfConcats(N, DAG))
return V;
}
// Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
// BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
- if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT)) {
- SmallVector<SDValue, 8> Ops;
- for (int M : SVN->getMask()) {
- SDValue Op = DAG.getUNDEF(VT.getScalarType());
- if (M >= 0) {
- int Idx = M % NumElts;
- SDValue &S = (M < (int)NumElts ? N0 : N1);
- if (S.getOpcode() == ISD::BUILD_VECTOR && S.hasOneUse()) {
- Op = S.getOperand(Idx);
- } else if (S.getOpcode() == ISD::SCALAR_TO_VECTOR && S.hasOneUse()) {
- if (Idx == 0)
- Op = S.getOperand(0);
- } else {
- // Operand can't be combined - bail out.
- break;
- }
- }
- Ops.push_back(Op);
- }
- if (Ops.size() == VT.getVectorNumElements()) {
- // BUILD_VECTOR requires all inputs to be of the same type, find the
- // maximum type and extend them all.
- EVT SVT = VT.getScalarType();
- if (SVT.isInteger())
- for (SDValue &Op : Ops)
- SVT = (SVT.bitsLT(Op.getValueType()) ? Op.getValueType() : SVT);
- if (SVT != VT.getScalarType())
- for (SDValue &Op : Ops)
- Op = TLI.isZExtFree(Op.getValueType(), SVT)
- ? DAG.getZExtOrTrunc(Op, SDLoc(N), SVT)
- : DAG.getSExtOrTrunc(Op, SDLoc(N), SVT);
- return DAG.getNode(ISD::BUILD_VECTOR, SDLoc(N), VT, Ops);
- }
- }
+ if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT))
+ if (SDValue Res = combineShuffleOfScalars(SVN, DAG, TLI))
+ return Res;
// If this shuffle only has a single input that is a bitcasted shuffle,
// attempt to merge the 2 shuffles and suitably bitcast the inputs/output
// back to their original types.
if (N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
- N1.getOpcode() == ISD::UNDEF && Level < AfterLegalizeVectorOps &&
+ N1.isUndef() && Level < AfterLegalizeVectorOps &&
TLI.isTypeLegal(VT)) {
-
- // Peek through the bitcast only if there is one user.
- SDValue BC0 = N0;
- while (BC0.getOpcode() == ISD::BITCAST) {
- if (!BC0.hasOneUse())
- break;
- BC0 = BC0.getOperand(0);
- }
-
auto ScaleShuffleMask = [](ArrayRef<int> Mask, int Scale) {
if (Scale == 1)
return SmallVector<int, 8>(Mask.begin(), Mask.end());
@@ -13402,7 +17712,8 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
NewMask.push_back(M < 0 ? -1 : Scale * M + s);
return NewMask;
};
-
+
+ SDValue BC0 = peekThroughOneUseBitcasts(N0);
if (BC0.getOpcode() == ISD::VECTOR_SHUFFLE && BC0.hasOneUse()) {
EVT SVT = VT.getScalarType();
EVT InnerVT = BC0->getValueType(0);
@@ -13415,7 +17726,6 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
if (TLI.isTypeLegal(ScaleVT) &&
0 == (InnerSVT.getSizeInBits() % ScaleSVT.getSizeInBits()) &&
0 == (SVT.getSizeInBits() % ScaleSVT.getSizeInBits())) {
-
int InnerScale = InnerSVT.getSizeInBits() / ScaleSVT.getSizeInBits();
int OuterScale = SVT.getSizeInBits() / ScaleSVT.getSizeInBits();
@@ -13442,11 +17752,10 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
}
if (LegalMask) {
- SV0 = DAG.getNode(ISD::BITCAST, SDLoc(N), ScaleVT, SV0);
- SV1 = DAG.getNode(ISD::BITCAST, SDLoc(N), ScaleVT, SV1);
- return DAG.getNode(
- ISD::BITCAST, SDLoc(N), VT,
- DAG.getVectorShuffle(ScaleVT, SDLoc(N), SV0, SV1, NewMask));
+ SV0 = DAG.getBitcast(ScaleVT, SV0);
+ SV1 = DAG.getBitcast(ScaleVT, SV1);
+ return DAG.getBitcast(
+ VT, DAG.getVectorShuffle(ScaleVT, SDLoc(N), SV0, SV1, NewMask));
}
}
}
@@ -13467,7 +17776,7 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
SDValue SV0 = N1->getOperand(0);
SDValue SV1 = N1->getOperand(1);
bool HasSameOp0 = N0 == SV0;
- bool IsSV1Undef = SV1.getOpcode() == ISD::UNDEF;
+ bool IsSV1Undef = SV1.isUndef();
if (HasSameOp0 || IsSV1Undef || N0 == SV1)
// Commute the operands of this shuffle so that next rule
// will trigger.
@@ -13484,6 +17793,11 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
ShuffleVectorSDNode *OtherSV = cast<ShuffleVectorSDNode>(N0);
+ // Don't try to fold splats; they're likely to simplify somehow, or they
+ // might be free.
+ if (OtherSV->isSplat())
+ return SDValue();
+
// The incoming shuffle must be of the same type as the result of the
// current shuffle.
assert(OtherSV->getOperand(0).getValueType() == VT &&
@@ -13520,7 +17834,7 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
}
// Simple case where 'CurrentVec' is UNDEF.
- if (CurrentVec.getOpcode() == ISD::UNDEF) {
+ if (CurrentVec.isUndef()) {
Mask.push_back(-1);
continue;
}
@@ -13575,7 +17889,7 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
// shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
// shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
// shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
- return DAG.getVectorShuffle(VT, SDLoc(N), SV0, SV1, &Mask[0]);
+ return DAG.getVectorShuffle(VT, SDLoc(N), SV0, SV1, Mask);
}
return SDValue();
@@ -13586,23 +17900,46 @@ SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) {
EVT VT = N->getValueType(0);
// Replace a SCALAR_TO_VECTOR(EXTRACT_VECTOR_ELT(V,C0)) pattern
- // with a VECTOR_SHUFFLE.
+ // with a VECTOR_SHUFFLE and possible truncate.
if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
SDValue InVec = InVal->getOperand(0);
SDValue EltNo = InVal->getOperand(1);
-
- // FIXME: We could support implicit truncation if the shuffle can be
- // scaled to a smaller vector scalar type.
- ConstantSDNode *C0 = dyn_cast<ConstantSDNode>(EltNo);
- if (C0 && VT == InVec.getValueType() &&
- VT.getScalarType() == InVal.getValueType()) {
- SmallVector<int, 8> NewMask(VT.getVectorNumElements(), -1);
+ auto InVecT = InVec.getValueType();
+ if (ConstantSDNode *C0 = dyn_cast<ConstantSDNode>(EltNo)) {
+ SmallVector<int, 8> NewMask(InVecT.getVectorNumElements(), -1);
int Elt = C0->getZExtValue();
NewMask[0] = Elt;
-
- if (TLI.isShuffleMaskLegal(NewMask, VT))
- return DAG.getVectorShuffle(VT, SDLoc(N), InVec, DAG.getUNDEF(VT),
- NewMask);
+ SDValue Val;
+ // If we have an implicit truncate, do the truncate here as long as it's
+ // legal; if it's not legal, this combine is simply skipped.
+ if (VT.getScalarType() != InVal.getValueType() &&
+ InVal.getValueType().isScalarInteger() &&
+ isTypeLegal(VT.getScalarType())) {
+ Val =
+ DAG.getNode(ISD::TRUNCATE, SDLoc(InVal), VT.getScalarType(), InVal);
+ return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), VT, Val);
+ }
+ if (VT.getScalarType() == InVecT.getScalarType() &&
+ VT.getVectorNumElements() <= InVecT.getVectorNumElements() &&
+ TLI.isShuffleMaskLegal(NewMask, VT)) {
+ Val = DAG.getVectorShuffle(InVecT, SDLoc(N), InVec,
+ DAG.getUNDEF(InVecT), NewMask);
+ // If the initial vector is the correct size, this shuffle is a
+ // valid result.
+ if (VT == InVecT)
+ return Val;
+ // If not, we must truncate the vector.
+ if (VT.getVectorNumElements() != InVecT.getVectorNumElements()) {
+ MVT IdxTy = TLI.getVectorIdxTy(DAG.getDataLayout());
+ SDValue ZeroIdx = DAG.getConstant(0, SDLoc(N), IdxTy);
+ EVT SubVT =
+ EVT::getVectorVT(*DAG.getContext(), InVecT.getVectorElementType(),
+ VT.getVectorNumElements());
+ Val = DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), SubVT, Val,
+ ZeroIdx);
+ return Val;
+ }
+ }
}
}
@@ -13610,29 +17947,109 @@ SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) {
}
SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) {
+ EVT VT = N->getValueType(0);
SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
SDValue N2 = N->getOperand(2);
+ // If inserting an UNDEF, just return the original vector.
+ if (N1.isUndef())
+ return N0;
+
+ // If this is an insert of an extracted vector into an undef vector, we can
+ // just use the input to the extract.
+ if (N0.isUndef() && N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+ N1.getOperand(1) == N2 && N1.getOperand(0).getValueType() == VT)
+ return N1.getOperand(0);
+
+ // If we are inserting a bitcast value into an undef, with the same
+ // number of elements, just use the bitcast input of the extract.
+ // i.e. INSERT_SUBVECTOR UNDEF (BITCAST N1) N2 ->
+ // BITCAST (INSERT_SUBVECTOR UNDEF N1 N2)
+ if (N0.isUndef() && N1.getOpcode() == ISD::BITCAST &&
+ N1.getOperand(0).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+ N1.getOperand(0).getOperand(1) == N2 &&
+ N1.getOperand(0).getOperand(0).getValueType().getVectorNumElements() ==
+ VT.getVectorNumElements() &&
+ N1.getOperand(0).getOperand(0).getValueType().getSizeInBits() ==
+ VT.getSizeInBits()) {
+ return DAG.getBitcast(VT, N1.getOperand(0).getOperand(0));
+ }
+
+ // If both N0 and N1 are bitcast values on which insert_subvector
+ // would make sense, pull the bitcast through.
+ // i.e. INSERT_SUBVECTOR (BITCAST N0) (BITCAST N1) N2 ->
+ // BITCAST (INSERT_SUBVECTOR N0 N1 N2)
+ if (N0.getOpcode() == ISD::BITCAST && N1.getOpcode() == ISD::BITCAST) {
+ SDValue CN0 = N0.getOperand(0);
+ SDValue CN1 = N1.getOperand(0);
+ EVT CN0VT = CN0.getValueType();
+ EVT CN1VT = CN1.getValueType();
+ if (CN0VT.isVector() && CN1VT.isVector() &&
+ CN0VT.getVectorElementType() == CN1VT.getVectorElementType() &&
+ CN0VT.getVectorNumElements() == VT.getVectorNumElements()) {
+ SDValue NewINSERT = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N),
+ CN0.getValueType(), CN0, CN1, N2);
+ return DAG.getBitcast(VT, NewINSERT);
+ }
+ }
+
+ // Combine INSERT_SUBVECTORs where we are inserting to the same index.
+ // INSERT_SUBVECTOR( INSERT_SUBVECTOR( Vec, SubOld, Idx ), SubNew, Idx )
+ // --> INSERT_SUBVECTOR( Vec, SubNew, Idx )
+ if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
+ N0.getOperand(1).getValueType() == N1.getValueType() &&
+ N0.getOperand(2) == N2)
+ return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0.getOperand(0),
+ N1, N2);
+
+ // Eliminate an intermediate insert into an undef vector:
+ // insert_subvector undef, (insert_subvector undef, X, 0), N2 -->
+ // insert_subvector undef, X, N2
+ if (N0.isUndef() && N1.getOpcode() == ISD::INSERT_SUBVECTOR &&
+ N1.getOperand(0).isUndef() && isNullConstant(N1.getOperand(2)))
+ return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0,
+ N1.getOperand(1), N2);
+
+ if (!isa<ConstantSDNode>(N2))
+ return SDValue();
+
+ unsigned InsIdx = cast<ConstantSDNode>(N2)->getZExtValue();
+
+ // Canonicalize insert_subvector dag nodes.
+ // Example:
+ // (insert_subvector (insert_subvector A, Idx0), Idx1)
+ // -> (insert_subvector (insert_subvector A, Idx1), Idx0)
+ if (N0.getOpcode() == ISD::INSERT_SUBVECTOR && N0.hasOneUse() &&
+ N1.getValueType() == N0.getOperand(1).getValueType() &&
+ isa<ConstantSDNode>(N0.getOperand(2))) {
+ unsigned OtherIdx = N0.getConstantOperandVal(2);
+ if (InsIdx < OtherIdx) {
+ // Swap nodes.
+ SDValue NewOp = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT,
+ N0.getOperand(0), N1, N2);
+ AddToWorklist(NewOp.getNode());
+ return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N0.getNode()),
+ VT, NewOp, N0.getOperand(1), N0.getOperand(2));
+ }
+ }
+
// If the input vector is a concatenation, and the insert replaces
- // one of the halves, we can optimize into a single concat_vectors.
- if (N0.getOpcode() == ISD::CONCAT_VECTORS &&
- N0->getNumOperands() == 2 && N2.getOpcode() == ISD::Constant) {
- APInt InsIdx = cast<ConstantSDNode>(N2)->getAPIntValue();
- EVT VT = N->getValueType(0);
+ // one of the pieces, we can optimize into a single concat_vectors.
+ if (N0.getOpcode() == ISD::CONCAT_VECTORS && N0.hasOneUse() &&
+ N0.getOperand(0).getValueType() == N1.getValueType()) {
+ unsigned Factor = N1.getValueType().getVectorNumElements();
- // Lower half: fold (insert_subvector (concat_vectors X, Y), Z) ->
- // (concat_vectors Z, Y)
- if (InsIdx == 0)
- return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT,
- N->getOperand(1), N0.getOperand(1));
+ SmallVector<SDValue, 8> Ops(N0->op_begin(), N0->op_end());
+ Ops[cast<ConstantSDNode>(N2)->getZExtValue() / Factor] = N1;
- // Upper half: fold (insert_subvector (concat_vectors X, Y), Z) ->
- // (concat_vectors X, Z)
- if (InsIdx == VT.getVectorNumElements()/2)
- return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT,
- N0.getOperand(0), N->getOperand(1));
+ return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
}
+ // Simplify source operands based on insertion.
+ if (SimplifyDemandedVectorElts(SDValue(N, 0)))
+ return SDValue(N, 0);
+
return SDValue();
}
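
For the concat rewrite at the end of this function, the only arithmetic is picking which concat operand the insert replaces. A sketch under the stated precondition (all concat pieces have N1's type; names are illustrative):

#include <cassert>
#include <string>
#include <vector>

int main() {
  // concat_vectors X, Y where each piece has 4 elements (Factor = 4);
  // insert_subvector of a 4-element Z at element index 4 replaces piece 1.
  unsigned Factor = 4;
  unsigned InsIdx = 4;
  std::vector<std::string> Ops = {"X", "Y"};
  Ops[InsIdx / Factor] = "Z";
  assert(Ops[0] == "X" && Ops[1] == "Z");  // --> concat_vectors X, Z
  return 0;
}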
@@ -13666,22 +18083,18 @@ SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
/// e.g. AND V, <0xffffffff, 0, 0xffffffff, 0>. ==>
/// vector_shuffle V, Zero, <0, 4, 2, 4>
SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) {
+ assert(N->getOpcode() == ISD::AND && "Unexpected opcode!");
+
EVT VT = N->getValueType(0);
SDValue LHS = N->getOperand(0);
- SDValue RHS = N->getOperand(1);
- SDLoc dl(N);
+ SDValue RHS = peekThroughBitcasts(N->getOperand(1));
+ SDLoc DL(N);
// Make sure we're not running after operation legalization where it
// may have custom lowered the vector shuffles.
if (LegalOperations)
return SDValue();
- if (N->getOpcode() != ISD::AND)
- return SDValue();
-
- if (RHS.getOpcode() == ISD::BITCAST)
- RHS = RHS.getOperand(0);
-
if (RHS.getOpcode() != ISD::BUILD_VECTOR)
return SDValue();
@@ -13700,7 +18113,7 @@ SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) {
int EltIdx = i / Split;
int SubIdx = i % Split;
SDValue Elt = RHS.getOperand(EltIdx);
- if (Elt.getOpcode() == ISD::UNDEF) {
+ if (Elt.isUndef()) {
Indices.push_back(-1);
continue;
}
@@ -13715,9 +18128,9 @@ SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) {
// Extract the sub element from the constant bit mask.
if (DAG.getDataLayout().isBigEndian()) {
- Bits = Bits.lshr((Split - SubIdx - 1) * NumSubBits);
+ Bits.lshrInPlace((Split - SubIdx - 1) * NumSubBits);
} else {
- Bits = Bits.lshr(SubIdx * NumSubBits);
+ Bits.lshrInPlace(SubIdx * NumSubBits);
}
if (Split > 1)
@@ -13737,10 +18150,10 @@ SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) {
if (!TLI.isVectorClearMaskLegal(Indices, ClearVT))
return SDValue();
- SDValue Zero = DAG.getConstant(0, dl, ClearVT);
- return DAG.getBitcast(VT, DAG.getVectorShuffle(ClearVT, dl,
+ SDValue Zero = DAG.getConstant(0, DL, ClearVT);
+ return DAG.getBitcast(VT, DAG.getVectorShuffle(ClearVT, DL,
DAG.getBitcast(ClearVT, LHS),
- Zero, &Indices[0]));
+ Zero, Indices));
};
// Determine maximum split level (byte level masking).
@@ -13770,17 +18183,13 @@ SDValue DAGCombiner::SimplifyVBinOp(SDNode *N) {
N->getOpcode(), SDLoc(LHS), LHS.getValueType(), Ops, N->getFlags()))
return Fold;
- // Try to convert a constant mask AND into a shuffle clear mask.
- if (SDValue Shuffle = XformToShuffleWithZero(N))
- return Shuffle;
-
// Type legalization might introduce new shuffles in the DAG.
// Fold (VBinOp (shuffle (A, Undef, Mask)), (shuffle (B, Undef, Mask)))
// -> (shuffle (VBinOp (A, B)), Undef, Mask).
if (LegalTypes && isa<ShuffleVectorSDNode>(LHS) &&
isa<ShuffleVectorSDNode>(RHS) && LHS.hasOneUse() && RHS.hasOneUse() &&
- LHS.getOperand(1).getOpcode() == ISD::UNDEF &&
- RHS.getOperand(1).getOpcode() == ISD::UNDEF) {
+ LHS.getOperand(1).isUndef() &&
+ RHS.getOperand(1).isUndef()) {
ShuffleVectorSDNode *SVN0 = cast<ShuffleVectorSDNode>(LHS);
ShuffleVectorSDNode *SVN1 = cast<ShuffleVectorSDNode>(RHS);
@@ -13792,15 +18201,15 @@ SDValue DAGCombiner::SimplifyVBinOp(SDNode *N) {
N->getFlags());
AddUsersToWorklist(N);
return DAG.getVectorShuffle(VT, SDLoc(N), NewBinOp, UndefVector,
- &SVN0->getMask()[0]);
+ SVN0->getMask());
}
}
return SDValue();
}
-SDValue DAGCombiner::SimplifySelect(SDLoc DL, SDValue N0,
- SDValue N1, SDValue N2){
+SDValue DAGCombiner::SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1,
+ SDValue N2) {
assert(N0.getOpcode() ==ISD::SETCC && "First argument must be a SetCC node!");
SDValue SCC = SimplifySelectCC(DL, N0.getOperand(0), N0.getOperand(1), N1, N2,
@@ -13834,34 +18243,33 @@ SDValue DAGCombiner::SimplifySelect(SDLoc DL, SDValue N0,
/// the DAG combiner loop to avoid it being looked at.
bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS,
SDValue RHS) {
-
- // fold (select (setcc x, -0.0, *lt), NaN, (fsqrt x))
- // The select + setcc is redundant, because fsqrt returns NaN for X < -0.
+ // fold (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
+ // The select + setcc is redundant, because fsqrt returns NaN for X < 0.
if (const ConstantFPSDNode *NaN = isConstOrConstSplatFP(LHS)) {
if (NaN->isNaN() && RHS.getOpcode() == ISD::FSQRT) {
// We have: (select (setcc ?, ?, ?), NaN, (fsqrt ?))
SDValue Sqrt = RHS;
ISD::CondCode CC;
SDValue CmpLHS;
- const ConstantFPSDNode *NegZero = nullptr;
+ const ConstantFPSDNode *Zero = nullptr;
if (TheSelect->getOpcode() == ISD::SELECT_CC) {
- CC = dyn_cast<CondCodeSDNode>(TheSelect->getOperand(4))->get();
+ CC = cast<CondCodeSDNode>(TheSelect->getOperand(4))->get();
CmpLHS = TheSelect->getOperand(0);
- NegZero = isConstOrConstSplatFP(TheSelect->getOperand(1));
+ Zero = isConstOrConstSplatFP(TheSelect->getOperand(1));
} else {
// SELECT or VSELECT
SDValue Cmp = TheSelect->getOperand(0);
if (Cmp.getOpcode() == ISD::SETCC) {
- CC = dyn_cast<CondCodeSDNode>(Cmp.getOperand(2))->get();
+ CC = cast<CondCodeSDNode>(Cmp.getOperand(2))->get();
CmpLHS = Cmp.getOperand(0);
- NegZero = isConstOrConstSplatFP(Cmp.getOperand(1));
+ Zero = isConstOrConstSplatFP(Cmp.getOperand(1));
}
}
- if (NegZero && NegZero->isNegative() && NegZero->isZero() &&
+ if (Zero && Zero->isZero() &&
Sqrt.getOperand(0) == CmpLHS && (CC == ISD::SETOLT ||
CC == ISD::SETULT || CC == ISD::SETLT)) {
- // We have: (select (setcc x, -0.0, *lt), NaN, (fsqrt x))
+ // We have: (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
CombineTo(TheSelect, Sqrt);
return true;
}
@@ -13909,31 +18317,64 @@ bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS,
LLD->getBasePtr().getValueType()))
return false;
+ // The loads must not depend on one another.
+ if (LLD->isPredecessorOf(RLD) || RLD->isPredecessorOf(LLD))
+ return false;
+
// Check that the select condition doesn't reach either load. If so,
// folding this will induce a cycle into the DAG. If not, this is safe to
// xform, so create a select of the addresses.
+
+ SmallPtrSet<const SDNode *, 32> Visited;
+ SmallVector<const SDNode *, 16> Worklist;
+
+ // Always fail if LLD and RLD are not independent. TheSelect is a
+ // predecessor to all Nodes in question so we need not search past it.
+
+ Visited.insert(TheSelect);
+ Worklist.push_back(LLD);
+ Worklist.push_back(RLD);
+
+ if (SDNode::hasPredecessorHelper(LLD, Visited, Worklist) ||
+ SDNode::hasPredecessorHelper(RLD, Visited, Worklist))
+ return false;
+
SDValue Addr;
if (TheSelect->getOpcode() == ISD::SELECT) {
+ // We cannot do this optimization if any pair of {RLD, LLD} is a
+ // predecessor to {RLD, LLD, CondNode}. As we've already compared the
+ // Loads, we only need to check if CondNode is a successor to one of the
+ // loads. We can further avoid this if there's no use of their chain
+ // value.
SDNode *CondNode = TheSelect->getOperand(0).getNode();
- if ((LLD->hasAnyUseOfValue(1) && LLD->isPredecessorOf(CondNode)) ||
- (RLD->hasAnyUseOfValue(1) && RLD->isPredecessorOf(CondNode)))
- return false;
- // The loads must not depend on one another.
- if (LLD->isPredecessorOf(RLD) ||
- RLD->isPredecessorOf(LLD))
+ Worklist.push_back(CondNode);
+
+ if ((LLD->hasAnyUseOfValue(1) &&
+ SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) ||
+ (RLD->hasAnyUseOfValue(1) &&
+ SDNode::hasPredecessorHelper(RLD, Visited, Worklist)))
return false;
+
Addr = DAG.getSelect(SDLoc(TheSelect),
LLD->getBasePtr().getValueType(),
TheSelect->getOperand(0), LLD->getBasePtr(),
RLD->getBasePtr());
} else { // Otherwise SELECT_CC
+ // We cannot do this optimization if any pair of {RLD, LLD} is a
+ // predecessor to {RLD, LLD, CondLHS, CondRHS}. As we've already compared
+ // the Loads, we only need to check if CondLHS/CondRHS is a successor to
+ // one of the loads. We can further avoid this if there's no use of their
+ // chain value.
+
SDNode *CondLHS = TheSelect->getOperand(0).getNode();
SDNode *CondRHS = TheSelect->getOperand(1).getNode();
+ Worklist.push_back(CondLHS);
+ Worklist.push_back(CondRHS);
if ((LLD->hasAnyUseOfValue(1) &&
- (LLD->isPredecessorOf(CondLHS) || LLD->isPredecessorOf(CondRHS))) ||
+ SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) ||
(RLD->hasAnyUseOfValue(1) &&
- (RLD->isPredecessorOf(CondLHS) || RLD->isPredecessorOf(CondRHS))))
+ SDNode::hasPredecessorHelper(RLD, Visited, Worklist)))
return false;
Addr = DAG.getNode(ISD::SELECT_CC, SDLoc(TheSelect),
@@ -13948,24 +18389,24 @@ bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS,
// It is safe to replace the two loads if they have different alignments,
// but the new load must be the minimum (most restrictive) alignment of the
// inputs.
- bool isInvariant = LLD->isInvariant() & RLD->isInvariant();
unsigned Alignment = std::min(LLD->getAlignment(), RLD->getAlignment());
+ MachineMemOperand::Flags MMOFlags = LLD->getMemOperand()->getFlags();
+ if (!RLD->isInvariant())
+ MMOFlags &= ~MachineMemOperand::MOInvariant;
+ if (!RLD->isDereferenceable())
+ MMOFlags &= ~MachineMemOperand::MODereferenceable;
if (LLD->getExtensionType() == ISD::NON_EXTLOAD) {
- Load = DAG.getLoad(TheSelect->getValueType(0),
- SDLoc(TheSelect),
- // FIXME: Discards pointer and AA info.
- LLD->getChain(), Addr, MachinePointerInfo(),
- LLD->isVolatile(), LLD->isNonTemporal(),
- isInvariant, Alignment);
+ // FIXME: Discards pointer and AA info.
+ Load = DAG.getLoad(TheSelect->getValueType(0), SDLoc(TheSelect),
+ LLD->getChain(), Addr, MachinePointerInfo(), Alignment,
+ MMOFlags);
} else {
- Load = DAG.getExtLoad(LLD->getExtensionType() == ISD::EXTLOAD ?
- RLD->getExtensionType() : LLD->getExtensionType(),
- SDLoc(TheSelect),
- TheSelect->getValueType(0),
- // FIXME: Discards pointer and AA info.
- LLD->getChain(), Addr, MachinePointerInfo(),
- LLD->getMemoryVT(), LLD->isVolatile(),
- LLD->isNonTemporal(), isInvariant, Alignment);
+ // FIXME: Discards pointer and AA info.
+ Load = DAG.getExtLoad(
+ LLD->getExtensionType() == ISD::EXTLOAD ? RLD->getExtensionType()
+ : LLD->getExtensionType(),
+ SDLoc(TheSelect), TheSelect->getValueType(0), LLD->getChain(), Addr,
+ MachinePointerInfo(), LLD->getMemoryVT(), Alignment, MMOFlags);
}
// Users of the select now use the result of the load.
@@ -13981,144 +18422,161 @@ bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS,
return false;
}
+/// Try to fold an expression of the form (N0 cond N1) ? N2 : N3 to a shift and
+/// bitwise 'and'.
+SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0,
+ SDValue N1, SDValue N2, SDValue N3,
+ ISD::CondCode CC) {
+ // If this is a select where the false operand is zero and the compare is a
+ // check of the sign bit, see if we can perform the "gzip trick":
+ // select_cc setlt X, 0, A, 0 -> and (sra X, size(X)-1), A
+ // select_cc setgt X, 0, A, 0 -> and (not (sra X, size(X)-1)), A
+ EVT XType = N0.getValueType();
+ EVT AType = N2.getValueType();
+ if (!isNullConstant(N3) || !XType.bitsGE(AType))
+ return SDValue();
+
+ // If the comparison is testing for a positive value, we have to invert
+ // the sign bit mask, so only do that transform if the target has a bitwise
+ // 'and not' instruction (the invert is free).
+ if (CC == ISD::SETGT && TLI.hasAndNot(N2)) {
+ // (X > -1) ? A : 0
+ // (X > 0) ? X : 0 <-- This is canonical signed max.
+ if (!(isAllOnesConstant(N1) || (isNullConstant(N1) && N0 == N2)))
+ return SDValue();
+ } else if (CC == ISD::SETLT) {
+ // (X < 0) ? A : 0
+ // (X < 1) ? X : 0 <-- This is un-canonicalized signed min.
+ if (!(isNullConstant(N1) || (isOneConstant(N1) && N0 == N2)))
+ return SDValue();
+ } else {
+ return SDValue();
+ }
+
+ // and (sra X, size(X)-1), A -> "and (srl X, C2), A" iff A is a single-bit
+ // constant.
+ EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
+ auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
+ if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue() - 1)) == 0)) {
+ unsigned ShCt = XType.getSizeInBits() - N2C->getAPIntValue().logBase2() - 1;
+ SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
+ SDValue Shift = DAG.getNode(ISD::SRL, DL, XType, N0, ShiftAmt);
+ AddToWorklist(Shift.getNode());
+
+ if (XType.bitsGT(AType)) {
+ Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
+ AddToWorklist(Shift.getNode());
+ }
+
+ if (CC == ISD::SETGT)
+ Shift = DAG.getNOT(DL, Shift, AType);
+
+ return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
+ }
+
+ SDValue ShiftAmt = DAG.getConstant(XType.getSizeInBits() - 1, DL, ShiftAmtTy);
+ SDValue Shift = DAG.getNode(ISD::SRA, DL, XType, N0, ShiftAmt);
+ AddToWorklist(Shift.getNode());
+
+ if (XType.bitsGT(AType)) {
+ Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
+ AddToWorklist(Shift.getNode());
+ }
+
+ if (CC == ISD::SETGT)
+ Shift = DAG.getNOT(DL, Shift, AType);
+
+ return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
+}
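+
+// Worked example (an illustrative sketch, not part of this patch): for i32 X,
+// the setlt form above lowers (X < 0) ? A : 0 without a branch or select:
+//
+//   uint32_t gzipTrick(int32_t X, uint32_t A) {
+//     return (uint32_t)(X >> 31) & A;  // sra smears the sign bit, then 'and'
+//   }
+//
+// When A is a single-bit constant such as 8, the cheaper srl form applies:
+// ShCt = 32 - log2(8) - 1 = 28, giving ((uint32_t)X >> 28) & 8.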
+
+/// Turn "(a cond b) ? 1.0f : 2.0f" into "load (tmp + ((a cond b) ? 0 : 4))"
+/// where "tmp" is a constant pool entry containing an array with 1.0 and 2.0
+/// in it. This may be a win when the constant is not otherwise available
+/// because it replaces two constant pool loads with one.
+SDValue DAGCombiner::convertSelectOfFPConstantsToLoadOffset(
+ const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
+ ISD::CondCode CC) {
+ if (!TLI.reduceSelectOfFPConstantLoads(N0.getValueType().isFloatingPoint()))
+ return SDValue();
+
+ // If we are before legalize types, we want the other legalization to happen
+ // first (for example, to avoid messing with soft float).
+ auto *TV = dyn_cast<ConstantFPSDNode>(N2);
+ auto *FV = dyn_cast<ConstantFPSDNode>(N3);
+ EVT VT = N2.getValueType();
+ if (!TV || !FV || !TLI.isTypeLegal(VT))
+ return SDValue();
+
+ // If a constant can be materialized without loads, this does not make sense.
+ if (TLI.getOperationAction(ISD::ConstantFP, VT) == TargetLowering::Legal ||
+ TLI.isFPImmLegal(TV->getValueAPF(), TV->getValueType(0)) ||
+ TLI.isFPImmLegal(FV->getValueAPF(), FV->getValueType(0)))
+ return SDValue();
+
+ // If both constants have multiple uses, then we won't need to do an extra
+ // load. The values are likely around in registers for other users.
+ if (!TV->hasOneUse() && !FV->hasOneUse())
+ return SDValue();
+
+ Constant *Elts[] = { const_cast<ConstantFP*>(FV->getConstantFPValue()),
+ const_cast<ConstantFP*>(TV->getConstantFPValue()) };
+ Type *FPTy = Elts[0]->getType();
+ const DataLayout &TD = DAG.getDataLayout();
+
+ // Create a ConstantArray of the two constants.
+ Constant *CA = ConstantArray::get(ArrayType::get(FPTy, 2), Elts);
+ SDValue CPIdx = DAG.getConstantPool(CA, TLI.getPointerTy(DAG.getDataLayout()),
+ TD.getPrefTypeAlignment(FPTy));
+ unsigned Alignment = cast<ConstantPoolSDNode>(CPIdx)->getAlignment();
+
+ // Get offsets to the 0 and 1 elements of the array, so we can select between
+ // them.
+ SDValue Zero = DAG.getIntPtrConstant(0, DL);
+ unsigned EltSize = (unsigned)TD.getTypeAllocSize(Elts[0]->getType());
+ SDValue One = DAG.getIntPtrConstant(EltSize, SDLoc(FV));
+ SDValue Cond =
+ DAG.getSetCC(DL, getSetCCResultType(N0.getValueType()), N0, N1, CC);
+ AddToWorklist(Cond.getNode());
+ SDValue CstOffset = DAG.getSelect(DL, Zero.getValueType(), Cond, One, Zero);
+ AddToWorklist(CstOffset.getNode());
+ CPIdx = DAG.getNode(ISD::ADD, DL, CPIdx.getValueType(), CPIdx, CstOffset);
+ AddToWorklist(CPIdx.getNode());
+ return DAG.getLoad(TV->getValueType(0), DL, DAG.getEntryNode(), CPIdx,
+ MachinePointerInfo::getConstantPool(
+ DAG.getMachineFunction()), Alignment);
+}
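+
+// Worked layout (a sketch assuming 4-byte floats): for "(a < b) ? 1.0f : 2.0f"
+// the constant pool entry is the array {2.0f, 1.0f} (false value first), so
+// the select reduces to address arithmetic on the pool pointer:
+//
+//   float selectFP(const float *tmp, bool cond) {  // tmp = &{2.0f, 1.0f}
+//     return tmp[cond ? 1 : 0];                    // byte offset 4 or 0
+//   }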
+
/// Simplify an expression of the form (N0 cond N1) ? N2 : N3
/// where 'cond' is the comparison specified by CC.
-SDValue DAGCombiner::SimplifySelectCC(SDLoc DL, SDValue N0, SDValue N1,
- SDValue N2, SDValue N3,
- ISD::CondCode CC, bool NotExtCompare) {
+SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
+ SDValue N2, SDValue N3, ISD::CondCode CC,
+ bool NotExtCompare) {
// (x ? y : y) -> y.
if (N2 == N3) return N2;
+ EVT CmpOpVT = N0.getValueType();
EVT VT = N2.getValueType();
- ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1.getNode());
- ConstantSDNode *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
+ auto *N1C = dyn_cast<ConstantSDNode>(N1.getNode());
+ auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
+ auto *N3C = dyn_cast<ConstantSDNode>(N3.getNode());
- // Determine if the condition we're dealing with is constant
- SDValue SCC = SimplifySetCC(getSetCCResultType(N0.getValueType()),
- N0, N1, CC, DL, false);
+ // Determine if the condition we're dealing with is constant.
+ SDValue SCC = SimplifySetCC(getSetCCResultType(CmpOpVT), N0, N1, CC, DL,
+ false);
if (SCC.getNode()) AddToWorklist(SCC.getNode());
- if (ConstantSDNode *SCCC = dyn_cast_or_null<ConstantSDNode>(SCC.getNode())) {
+ if (auto *SCCC = dyn_cast_or_null<ConstantSDNode>(SCC.getNode())) {
// fold select_cc true, x, y -> x
// fold select_cc false, x, y -> y
return !SCCC->isNullValue() ? N2 : N3;
}
- // Check to see if we can simplify the select into an fabs node
- if (ConstantFPSDNode *CFP = dyn_cast<ConstantFPSDNode>(N1)) {
- // Allow either -0.0 or 0.0
- if (CFP->isZero()) {
- // select (setg[te] X, +/-0.0), X, fneg(X) -> fabs
- if ((CC == ISD::SETGE || CC == ISD::SETGT) &&
- N0 == N2 && N3.getOpcode() == ISD::FNEG &&
- N2 == N3.getOperand(0))
- return DAG.getNode(ISD::FABS, DL, VT, N0);
-
- // select (setl[te] X, +/-0.0), fneg(X), X -> fabs
- if ((CC == ISD::SETLT || CC == ISD::SETLE) &&
- N0 == N3 && N2.getOpcode() == ISD::FNEG &&
- N2.getOperand(0) == N3)
- return DAG.getNode(ISD::FABS, DL, VT, N3);
- }
- }
-
- // Turn "(a cond b) ? 1.0f : 2.0f" into "load (tmp + ((a cond b) ? 0 : 4)"
- // where "tmp" is a constant pool entry containing an array with 1.0 and 2.0
- // in it. This is a win when the constant is not otherwise available because
- // it replaces two constant pool loads with one. We only do this if the FP
- // type is known to be legal, because if it isn't, then we are before legalize
- // types an we want the other legalization to happen first (e.g. to avoid
- // messing with soft float) and if the ConstantFP is not legal, because if
- // it is legal, we may not need to store the FP constant in a constant pool.
- if (ConstantFPSDNode *TV = dyn_cast<ConstantFPSDNode>(N2))
- if (ConstantFPSDNode *FV = dyn_cast<ConstantFPSDNode>(N3)) {
- if (TLI.isTypeLegal(N2.getValueType()) &&
- (TLI.getOperationAction(ISD::ConstantFP, N2.getValueType()) !=
- TargetLowering::Legal &&
- !TLI.isFPImmLegal(TV->getValueAPF(), TV->getValueType(0)) &&
- !TLI.isFPImmLegal(FV->getValueAPF(), FV->getValueType(0))) &&
- // If both constants have multiple uses, then we won't need to do an
- // extra load, they are likely around in registers for other users.
- (TV->hasOneUse() || FV->hasOneUse())) {
- Constant *Elts[] = {
- const_cast<ConstantFP*>(FV->getConstantFPValue()),
- const_cast<ConstantFP*>(TV->getConstantFPValue())
- };
- Type *FPTy = Elts[0]->getType();
- const DataLayout &TD = DAG.getDataLayout();
-
- // Create a ConstantArray of the two constants.
- Constant *CA = ConstantArray::get(ArrayType::get(FPTy, 2), Elts);
- SDValue CPIdx =
- DAG.getConstantPool(CA, TLI.getPointerTy(DAG.getDataLayout()),
- TD.getPrefTypeAlignment(FPTy));
- unsigned Alignment = cast<ConstantPoolSDNode>(CPIdx)->getAlignment();
-
- // Get the offsets to the 0 and 1 element of the array so that we can
- // select between them.
- SDValue Zero = DAG.getIntPtrConstant(0, DL);
- unsigned EltSize = (unsigned)TD.getTypeAllocSize(Elts[0]->getType());
- SDValue One = DAG.getIntPtrConstant(EltSize, SDLoc(FV));
-
- SDValue Cond = DAG.getSetCC(DL,
- getSetCCResultType(N0.getValueType()),
- N0, N1, CC);
- AddToWorklist(Cond.getNode());
- SDValue CstOffset = DAG.getSelect(DL, Zero.getValueType(),
- Cond, One, Zero);
- AddToWorklist(CstOffset.getNode());
- CPIdx = DAG.getNode(ISD::ADD, DL, CPIdx.getValueType(), CPIdx,
- CstOffset);
- AddToWorklist(CPIdx.getNode());
- return DAG.getLoad(
- TV->getValueType(0), DL, DAG.getEntryNode(), CPIdx,
- MachinePointerInfo::getConstantPool(DAG.getMachineFunction()),
- false, false, false, Alignment);
- }
- }
-
- // Check to see if we can perform the "gzip trick", transforming
- // (select_cc setlt X, 0, A, 0) -> (and (sra X, (sub size(X), 1), A)
- if (isNullConstant(N3) && CC == ISD::SETLT &&
- (isNullConstant(N1) || // (a < 0) ? b : 0
- (isOneConstant(N1) && N0 == N2))) { // (a < 1) ? a : 0
- EVT XType = N0.getValueType();
- EVT AType = N2.getValueType();
- if (XType.bitsGE(AType)) {
- // and (sra X, size(X)-1, A) -> "and (srl X, C2), A" iff A is a
- // single-bit constant.
- if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue() - 1)) == 0)) {
- unsigned ShCtV = N2C->getAPIntValue().logBase2();
- ShCtV = XType.getSizeInBits() - ShCtV - 1;
- SDValue ShCt = DAG.getConstant(ShCtV, SDLoc(N0),
- getShiftAmountTy(N0.getValueType()));
- SDValue Shift = DAG.getNode(ISD::SRL, SDLoc(N0),
- XType, N0, ShCt);
- AddToWorklist(Shift.getNode());
-
- if (XType.bitsGT(AType)) {
- Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
- AddToWorklist(Shift.getNode());
- }
-
- return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
- }
-
- SDValue Shift = DAG.getNode(ISD::SRA, SDLoc(N0),
- XType, N0,
- DAG.getConstant(XType.getSizeInBits() - 1,
- SDLoc(N0),
- getShiftAmountTy(N0.getValueType())));
- AddToWorklist(Shift.getNode());
-
- if (XType.bitsGT(AType)) {
- Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
- AddToWorklist(Shift.getNode());
- }
+ if (SDValue V =
+ convertSelectOfFPConstantsToLoadOffset(DL, N0, N1, N2, N3, CC))
+ return V;
- return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
- }
- }
+ if (SDValue V = foldSelectCCToShiftAnd(DL, N0, N1, N2, N3, CC))
+ return V;
// fold (select_cc seteq (and x, y), 0, 0, A) -> (and (shr (shl x)) A)
  // where y has a single bit set.
@@ -14129,10 +18587,10 @@ SDValue DAGCombiner::SimplifySelectCC(SDLoc DL, SDValue N0, SDValue N1,
if (CC == ISD::SETEQ && N0->getOpcode() == ISD::AND &&
N0->getValueType(0) == VT && isNullConstant(N1) && isNullConstant(N2)) {
SDValue AndLHS = N0->getOperand(0);
- ConstantSDNode *ConstAndRHS = dyn_cast<ConstantSDNode>(N0->getOperand(1));
+ auto *ConstAndRHS = dyn_cast<ConstantSDNode>(N0->getOperand(1));
if (ConstAndRHS && ConstAndRHS->getAPIntValue().countPopulation() == 1) {
// Shift the tested bit over the sign bit.
- APInt AndMask = ConstAndRHS->getAPIntValue();
+ const APInt &AndMask = ConstAndRHS->getAPIntValue();
SDValue ShlAmt =
DAG.getConstant(AndMask.countLeadingZeros(), SDLoc(AndLHS),
getShiftAmountTy(AndLHS.getValueType()));
@@ -14150,48 +18608,48 @@ SDValue DAGCombiner::SimplifySelectCC(SDLoc DL, SDValue N0, SDValue N1,
}
// fold select C, 16, 0 -> shl C, 4
- if (N2C && isNullConstant(N3) && N2C->getAPIntValue().isPowerOf2() &&
- TLI.getBooleanContents(N0.getValueType()) ==
- TargetLowering::ZeroOrOneBooleanContent) {
+ bool Fold = N2C && isNullConstant(N3) && N2C->getAPIntValue().isPowerOf2();
+ bool Swap = N3C && isNullConstant(N2) && N3C->getAPIntValue().isPowerOf2();
+
+ if ((Fold || Swap) &&
+ TLI.getBooleanContents(CmpOpVT) ==
+ TargetLowering::ZeroOrOneBooleanContent &&
+ (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, CmpOpVT))) {
+
+ if (Swap) {
+ CC = ISD::getSetCCInverse(CC, CmpOpVT.isInteger());
+ std::swap(N2C, N3C);
+ }
// If the caller doesn't want us to simplify this into a zext of a compare,
// don't do it.
if (NotExtCompare && N2C->isOne())
return SDValue();
- // Get a SetCC of the condition
- // NOTE: Don't create a SETCC if it's not legal on this target.
- if (!LegalOperations ||
- TLI.isOperationLegal(ISD::SETCC, N0.getValueType())) {
- SDValue Temp, SCC;
- // cast from setcc result type to select result type
- if (LegalTypes) {
- SCC = DAG.getSetCC(DL, getSetCCResultType(N0.getValueType()),
- N0, N1, CC);
- if (N2.getValueType().bitsLT(SCC.getValueType()))
- Temp = DAG.getZeroExtendInReg(SCC, SDLoc(N2),
- N2.getValueType());
- else
- Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2),
- N2.getValueType(), SCC);
- } else {
- SCC = DAG.getSetCC(SDLoc(N0), MVT::i1, N0, N1, CC);
- Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2),
- N2.getValueType(), SCC);
- }
+ SDValue Temp, SCC;
+ // zext (setcc n0, n1)
+ if (LegalTypes) {
+ SCC = DAG.getSetCC(DL, getSetCCResultType(CmpOpVT), N0, N1, CC);
+ if (VT.bitsLT(SCC.getValueType()))
+ Temp = DAG.getZeroExtendInReg(SCC, SDLoc(N2), VT);
+ else
+ Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), VT, SCC);
+ } else {
+ SCC = DAG.getSetCC(SDLoc(N0), MVT::i1, N0, N1, CC);
+ Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), VT, SCC);
+ }
- AddToWorklist(SCC.getNode());
- AddToWorklist(Temp.getNode());
+ AddToWorklist(SCC.getNode());
+ AddToWorklist(Temp.getNode());
- if (N2C->isOne())
- return Temp;
+ if (N2C->isOne())
+ return Temp;
- // shl setcc result by log2 n2c
- return DAG.getNode(
- ISD::SHL, DL, N2.getValueType(), Temp,
- DAG.getConstant(N2C->getAPIntValue().logBase2(), SDLoc(Temp),
- getShiftAmountTy(Temp.getValueType())));
- }
+ // shl setcc result by log2 n2c
+ return DAG.getNode(ISD::SHL, DL, N2.getValueType(), Temp,
+ DAG.getConstant(N2C->getAPIntValue().logBase2(),
+ SDLoc(Temp),
+ getShiftAmountTy(Temp.getValueType())));
}
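+
+  // Worked sketch (illustrative, not part of this patch): with
+  // ZeroOrOneBooleanContent, (X == Y) ? 16 : 0 becomes
+  //   zext(setcc X, Y, eq) << 4       // 16 == 1 << log2(16)
+  // and the Swap case first rewrites (X == Y) ? 0 : 16 as (X != Y) ? 16 : 0
+  // using the inverted condition code.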
// Check to see if this is an integer abs.
@@ -14211,18 +18669,51 @@ SDValue DAGCombiner::SimplifySelectCC(SDLoc DL, SDValue N0, SDValue N1,
N0 == N3 && N2.getOpcode() == ISD::SUB && N0 == N2.getOperand(1))
SubC = dyn_cast<ConstantSDNode>(N2.getOperand(0));
- EVT XType = N0.getValueType();
- if (SubC && SubC->isNullValue() && XType.isInteger()) {
+ if (SubC && SubC->isNullValue() && CmpOpVT.isInteger()) {
SDLoc DL(N0);
- SDValue Shift = DAG.getNode(ISD::SRA, DL, XType,
- N0,
- DAG.getConstant(XType.getSizeInBits() - 1, DL,
- getShiftAmountTy(N0.getValueType())));
- SDValue Add = DAG.getNode(ISD::ADD, DL,
- XType, N0, Shift);
+ SDValue Shift = DAG.getNode(ISD::SRA, DL, CmpOpVT, N0,
+ DAG.getConstant(CmpOpVT.getSizeInBits() - 1,
+ DL,
+ getShiftAmountTy(CmpOpVT)));
+ SDValue Add = DAG.getNode(ISD::ADD, DL, CmpOpVT, N0, Shift);
AddToWorklist(Shift.getNode());
AddToWorklist(Add.getNode());
- return DAG.getNode(ISD::XOR, DL, XType, Add, Shift);
+ return DAG.getNode(ISD::XOR, DL, CmpOpVT, Add, Shift);
+ }
+ }
+
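+  // Worked sketch of the integer abs lowering above for i32 (illustrative):
+  //   int32_t S = X >> 31;        // sra: 0 for X >= 0, -1 for X < 0
+  //   int32_t Abs = (X + S) ^ S;  // add then xor undoes the negation
+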
+ // select_cc seteq X, 0, sizeof(X), ctlz(X) -> ctlz(X)
+ // select_cc seteq X, 0, sizeof(X), ctlz_zero_undef(X) -> ctlz(X)
+ // select_cc seteq X, 0, sizeof(X), cttz(X) -> cttz(X)
+ // select_cc seteq X, 0, sizeof(X), cttz_zero_undef(X) -> cttz(X)
+ // select_cc setne X, 0, ctlz(X), sizeof(X) -> ctlz(X)
+ // select_cc setne X, 0, ctlz_zero_undef(X), sizeof(X) -> ctlz(X)
+ // select_cc setne X, 0, cttz(X), sizeof(X) -> cttz(X)
+ // select_cc setne X, 0, cttz_zero_undef(X), sizeof(X) -> cttz(X)
+ if (N1C && N1C->isNullValue() && (CC == ISD::SETEQ || CC == ISD::SETNE)) {
+ SDValue ValueOnZero = N2;
+ SDValue Count = N3;
+    // If the condition is NE instead of EQ, swap the operands.
+ if (CC == ISD::SETNE)
+ std::swap(ValueOnZero, Count);
+ // Check if the value on zero is a constant equal to the bits in the type.
+ if (auto *ValueOnZeroC = dyn_cast<ConstantSDNode>(ValueOnZero)) {
+ if (ValueOnZeroC->getAPIntValue() == VT.getSizeInBits()) {
+ // If the other operand is cttz/cttz_zero_undef of N0, and cttz is
+ // legal, combine to just cttz.
+ if ((Count.getOpcode() == ISD::CTTZ ||
+ Count.getOpcode() == ISD::CTTZ_ZERO_UNDEF) &&
+ N0 == Count.getOperand(0) &&
+ (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ, VT)))
+ return DAG.getNode(ISD::CTTZ, DL, VT, N0);
+ // If the other operand is ctlz/ctlz_zero_undef of N0, and ctlz is
+ // legal, combine to just ctlz.
+ if ((Count.getOpcode() == ISD::CTLZ ||
+ Count.getOpcode() == ISD::CTLZ_ZERO_UNDEF) &&
+ N0 == Count.getOperand(0) &&
+ (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ, VT)))
+ return DAG.getNode(ISD::CTLZ, DL, VT, N0);
+ }
}
}
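+
+  // For example (sketch): "X == 0 ? 32 : cttz(X)" on i32 folds to cttz(X),
+  // because plain ISD::CTTZ (unlike the _ZERO_UNDEF form) is defined to
+  // return the bit width when its input is zero, making the select redundant.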
@@ -14230,9 +18721,9 @@ SDValue DAGCombiner::SimplifySelectCC(SDLoc DL, SDValue N0, SDValue N1,
}
/// This is a stub for TargetLowering::SimplifySetCC.
-SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0,
- SDValue N1, ISD::CondCode Cond,
- SDLoc DL, bool foldBooleans) {
+SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
+ ISD::CondCode Cond, const SDLoc &DL,
+ bool foldBooleans) {
TargetLowering::DAGCombinerInfo
DagCombineInfo(DAG, Level, false, this);
return TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DagCombineInfo, DL);
@@ -14243,21 +18734,19 @@ SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0,
/// by a magic number.
/// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
SDValue DAGCombiner::BuildSDIV(SDNode *N) {
- ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
- if (!C)
- return SDValue();
-
- // Avoid division by zero.
- if (C->isNullValue())
+  // When optimizing for minimum size, we don't want to expand a div to a mul
+  // and a shift.
+ if (DAG.getMachineFunction().getFunction().optForMinSize())
return SDValue();
- std::vector<SDNode*> Built;
- SDValue S =
- TLI.BuildSDIV(N, C->getAPIntValue(), DAG, LegalOperations, &Built);
+ SmallVector<SDNode *, 8> Built;
+ if (SDValue S = TLI.BuildSDIV(N, DAG, LegalOperations, Built)) {
+ for (SDNode *N : Built)
+ AddToWorklist(N);
+ return S;
+ }
- for (SDNode *N : Built)
- AddToWorklist(N);
- return S;
+ return SDValue();
}
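+
+// Worked example of the magic-number expansion (a sketch; the real constants
+// come from TLI.BuildSDIV): signed i32 division by 3 becomes
+//
+//   int32_t divs3(int32_t n) {
+//     int32_t q = (int32_t)(((int64_t)n * 0x55555556) >> 32);  // mulhs
+//     return q + (int32_t)((uint32_t)n >> 31);  // add 1 when n is negative
+//   }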
/// Given an ISD::SDIV node expressing a divide by constant power of 2, return a
@@ -14271,12 +18760,14 @@ SDValue DAGCombiner::BuildSDIVPow2(SDNode *N) {
if (C->isNullValue())
return SDValue();
- std::vector<SDNode *> Built;
- SDValue S = TLI.BuildSDIVPow2(N, C->getAPIntValue(), DAG, &Built);
+ SmallVector<SDNode *, 8> Built;
+ if (SDValue S = TLI.BuildSDIVPow2(N, C->getAPIntValue(), DAG, Built)) {
+ for (SDNode *N : Built)
+ AddToWorklist(N);
+ return S;
+ }
- for (SDNode *N : Built)
- AddToWorklist(N);
- return S;
+ return SDValue();
}
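+
+// For a power-of-2 divisor the expansion is simpler (sketch for i32 n / 4):
+//
+//   int32_t divs4(int32_t n) {
+//     int32_t bias = (int32_t)((uint32_t)(n >> 31) >> 30);  // 3 if n < 0
+//     return (n + bias) >> 2;  // biasing rounds the shift toward zero
+//   }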
/// Given an ISD::UDIV node expressing a divide by constant, return a DAG
@@ -14284,47 +18775,66 @@ SDValue DAGCombiner::BuildSDIVPow2(SDNode *N) {
/// number.
/// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
SDValue DAGCombiner::BuildUDIV(SDNode *N) {
- ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
- if (!C)
+  // When optimizing for minimum size, we don't want to expand a div to a mul
+  // and a shift.
+ if (DAG.getMachineFunction().getFunction().optForMinSize())
return SDValue();
- // Avoid division by zero.
- if (C->isNullValue())
- return SDValue();
+ SmallVector<SDNode *, 8> Built;
+ if (SDValue S = TLI.BuildUDIV(N, DAG, LegalOperations, Built)) {
+ for (SDNode *N : Built)
+ AddToWorklist(N);
+ return S;
+ }
- std::vector<SDNode*> Built;
- SDValue S =
- TLI.BuildUDIV(N, C->getAPIntValue(), DAG, LegalOperations, &Built);
+ return SDValue();
+}
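+
+// Unsigned division uses the same trick (sketch; constants from
+// TLI.BuildUDIV): for i32 n / 3 the magic is 0xAAAAAAAB:
+//
+//   uint32_t divu3(uint32_t n) {
+//     uint32_t q = (uint32_t)(((uint64_t)n * 0xAAAAAAABu) >> 32);  // mulhu
+//     return q >> 1;
+//   }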
- for (SDNode *N : Built)
- AddToWorklist(N);
- return S;
+/// Determines the LogBase2 value for a non-null input value using the
+/// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
+SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL) {
+ EVT VT = V.getValueType();
+ unsigned EltBits = VT.getScalarSizeInBits();
+ SDValue Ctlz = DAG.getNode(ISD::CTLZ, DL, VT, V);
+ SDValue Base = DAG.getConstant(EltBits - 1, DL, VT);
+ SDValue LogBase2 = DAG.getNode(ISD::SUB, DL, VT, Base, Ctlz);
+ return LogBase2;
}
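+
+// For example, with i32 V == 16: ctlz(16) == 27, so LogBase2 == 31 - 27 == 4.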
-SDValue DAGCombiner::BuildReciprocalEstimate(SDValue Op, SDNodeFlags *Flags) {
+/// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
+/// For the reciprocal, we need to find the zero of the function:
+/// F(X) = A X - 1 [which has a zero at X = 1/A]
+/// =>
+/// X_{i+1} = X_i (2 - A X_i) = X_i + X_i (1 - A X_i) [this second form
+/// does not require additional intermediate precision]
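+///
+/// For example (illustrative numbers), with A = 3 and initial estimate
+/// X_0 = 0.25: X_1 = 0.25 * (2 - 3 * 0.25) = 0.3125 and
+/// X_2 = 0.3125 * (2 - 3 * 0.3125) = 0.33203125, converging toward 1/3.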
+SDValue DAGCombiner::BuildReciprocalEstimate(SDValue Op, SDNodeFlags Flags) {
if (Level >= AfterLegalizeDAG)
return SDValue();
- // Expose the DAG combiner to the target combiner implementations.
- TargetLowering::DAGCombinerInfo DCI(DAG, Level, false, this);
+ // TODO: Handle half and/or extended types?
+ EVT VT = Op.getValueType();
+ if (VT.getScalarType() != MVT::f32 && VT.getScalarType() != MVT::f64)
+ return SDValue();
+
+ // If estimates are explicitly disabled for this function, we're done.
+ MachineFunction &MF = DAG.getMachineFunction();
+ int Enabled = TLI.getRecipEstimateDivEnabled(VT, MF);
+ if (Enabled == TLI.ReciprocalEstimate::Disabled)
+ return SDValue();
+
+ // Estimates may be explicitly enabled for this type with a custom number of
+ // refinement steps.
+ int Iterations = TLI.getDivRefinementSteps(VT, MF);
+ if (SDValue Est = TLI.getRecipEstimate(Op, DAG, Enabled, Iterations)) {
+ AddToWorklist(Est.getNode());
- unsigned Iterations = 0;
- if (SDValue Est = TLI.getRecipEstimate(Op, DCI, Iterations)) {
if (Iterations) {
- // Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
- // For the reciprocal, we need to find the zero of the function:
- // F(X) = A X - 1 [which has a zero at X = 1/A]
- // =>
- // X_{i+1} = X_i (2 - A X_i) = X_i + X_i (1 - A X_i) [this second form
- // does not require additional intermediate precision]
EVT VT = Op.getValueType();
SDLoc DL(Op);
SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);
- AddToWorklist(Est.getNode());
-
// Newton iterations: Est = Est + Est (1 - Arg * Est)
- for (unsigned i = 0; i < Iterations; ++i) {
+ for (int i = 0; i < Iterations; ++i) {
SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Op, Est, Flags);
AddToWorklist(NewEst.getNode());
@@ -14350,9 +18860,9 @@ SDValue DAGCombiner::BuildReciprocalEstimate(SDValue Op, SDNodeFlags *Flags) {
/// =>
/// X_{i+1} = X_i (1.5 - A X_i^2 / 2)
/// As a result, we precompute A/2 prior to the iteration loop.
-SDValue DAGCombiner::BuildRsqrtNROneConst(SDValue Arg, SDValue Est,
- unsigned Iterations,
- SDNodeFlags *Flags) {
+SDValue DAGCombiner::buildSqrtNROneConst(SDValue Arg, SDValue Est,
+ unsigned Iterations,
+ SDNodeFlags Flags, bool Reciprocal) {
EVT VT = Arg.getValueType();
SDLoc DL(Arg);
SDValue ThreeHalves = DAG.getConstantFP(1.5, DL, VT);
@@ -14379,6 +18889,13 @@ SDValue DAGCombiner::BuildRsqrtNROneConst(SDValue Arg, SDValue Est,
Est = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
AddToWorklist(Est.getNode());
}
+
+ // If non-reciprocal square root is requested, multiply the result by Arg.
+ if (!Reciprocal) {
+ Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Arg, Flags);
+ AddToWorklist(Est.getNode());
+ }
+
return Est;
}
@@ -14387,48 +18904,114 @@ SDValue DAGCombiner::BuildRsqrtNROneConst(SDValue Arg, SDValue Est,
/// F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
/// =>
/// X_{i+1} = (-0.5 * X_i) * (A * X_i * X_i + (-3.0))
-SDValue DAGCombiner::BuildRsqrtNRTwoConst(SDValue Arg, SDValue Est,
- unsigned Iterations,
- SDNodeFlags *Flags) {
+SDValue DAGCombiner::buildSqrtNRTwoConst(SDValue Arg, SDValue Est,
+ unsigned Iterations,
+ SDNodeFlags Flags, bool Reciprocal) {
EVT VT = Arg.getValueType();
SDLoc DL(Arg);
SDValue MinusThree = DAG.getConstantFP(-3.0, DL, VT);
SDValue MinusHalf = DAG.getConstantFP(-0.5, DL, VT);
- // Newton iterations: Est = -0.5 * Est * (-3.0 + Arg * Est * Est)
- for (unsigned i = 0; i < Iterations; ++i) {
- SDValue HalfEst = DAG.getNode(ISD::FMUL, DL, VT, Est, MinusHalf, Flags);
- AddToWorklist(HalfEst.getNode());
-
- Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Est, Flags);
- AddToWorklist(Est.getNode());
-
- Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Arg, Flags);
- AddToWorklist(Est.getNode());
+ // This routine must enter the loop below to work correctly
+ // when (Reciprocal == false).
+ assert(Iterations > 0);
- Est = DAG.getNode(ISD::FADD, DL, VT, Est, MinusThree, Flags);
- AddToWorklist(Est.getNode());
+ // Newton iterations for reciprocal square root:
+ // E = (E * -0.5) * ((A * E) * E + -3.0)
+ for (unsigned i = 0; i < Iterations; ++i) {
+ SDValue AE = DAG.getNode(ISD::FMUL, DL, VT, Arg, Est, Flags);
+ AddToWorklist(AE.getNode());
+
+ SDValue AEE = DAG.getNode(ISD::FMUL, DL, VT, AE, Est, Flags);
+ AddToWorklist(AEE.getNode());
+
+ SDValue RHS = DAG.getNode(ISD::FADD, DL, VT, AEE, MinusThree, Flags);
+ AddToWorklist(RHS.getNode());
+
+ // When calculating a square root at the last iteration build:
+ // S = ((A * E) * -0.5) * ((A * E) * E + -3.0)
+ // (notice a common subexpression)
+ SDValue LHS;
+ if (Reciprocal || (i + 1) < Iterations) {
+ // RSQRT: LHS = (E * -0.5)
+ LHS = DAG.getNode(ISD::FMUL, DL, VT, Est, MinusHalf, Flags);
+ } else {
+ // SQRT: LHS = (A * E) * -0.5
+ LHS = DAG.getNode(ISD::FMUL, DL, VT, AE, MinusHalf, Flags);
+ }
+ AddToWorklist(LHS.getNode());
- Est = DAG.getNode(ISD::FMUL, DL, VT, Est, HalfEst, Flags);
+ Est = DAG.getNode(ISD::FMUL, DL, VT, LHS, RHS, Flags);
AddToWorklist(Est.getNode());
}
+
return Est;
}
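+
+// Numeric check of one iteration above (illustrative): A = 4, E_0 = 0.6:
+//   AE = 2.4, AEE = 1.44, RHS = 1.44 + -3.0 = -1.56, LHS = 0.6 * -0.5 = -0.3,
+//   E_1 = -0.3 * -1.56 = 0.468, which is closer to 1/sqrt(4) = 0.5.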
-SDValue DAGCombiner::BuildRsqrtEstimate(SDValue Op, SDNodeFlags *Flags) {
+/// Build code to calculate either rsqrt(Op) or sqrt(Op). In the latter case
+/// Op*rsqrt(Op) is actually computed, so additional postprocessing is needed if
+/// Op can be zero.
+SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
+ bool Reciprocal) {
if (Level >= AfterLegalizeDAG)
return SDValue();
- // Expose the DAG combiner to the target combiner implementations.
- TargetLowering::DAGCombinerInfo DCI(DAG, Level, false, this);
- unsigned Iterations = 0;
+ // TODO: Handle half and/or extended types?
+ EVT VT = Op.getValueType();
+ if (VT.getScalarType() != MVT::f32 && VT.getScalarType() != MVT::f64)
+ return SDValue();
+
+ // If estimates are explicitly disabled for this function, we're done.
+ MachineFunction &MF = DAG.getMachineFunction();
+ int Enabled = TLI.getRecipEstimateSqrtEnabled(VT, MF);
+ if (Enabled == TLI.ReciprocalEstimate::Disabled)
+ return SDValue();
+
+ // Estimates may be explicitly enabled for this type with a custom number of
+ // refinement steps.
+ int Iterations = TLI.getSqrtRefinementSteps(VT, MF);
+
bool UseOneConstNR = false;
- if (SDValue Est = TLI.getRsqrtEstimate(Op, DCI, Iterations, UseOneConstNR)) {
+ if (SDValue Est =
+ TLI.getSqrtEstimate(Op, DAG, Enabled, Iterations, UseOneConstNR,
+ Reciprocal)) {
AddToWorklist(Est.getNode());
+
if (Iterations) {
- Est = UseOneConstNR ?
- BuildRsqrtNROneConst(Op, Est, Iterations, Flags) :
- BuildRsqrtNRTwoConst(Op, Est, Iterations, Flags);
+ Est = UseOneConstNR
+ ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal)
+ : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal);
+
+ if (!Reciprocal) {
+ // The estimate is now completely wrong if the input was exactly 0.0 or
+ // possibly a denormal. Force the answer to 0.0 for those cases.
+ EVT VT = Op.getValueType();
+ SDLoc DL(Op);
+ EVT CCVT = getSetCCResultType(VT);
+ ISD::NodeType SelOpcode = VT.isVector() ? ISD::VSELECT : ISD::SELECT;
+ const Function &F = DAG.getMachineFunction().getFunction();
+ Attribute Denorms = F.getFnAttribute("denormal-fp-math");
+ if (Denorms.getValueAsString().equals("ieee")) {
+ // fabs(X) < SmallestNormal ? 0.0 : Est
+ const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(VT);
+ APFloat SmallestNorm = APFloat::getSmallestNormalized(FltSem);
+ SDValue NormC = DAG.getConstantFP(SmallestNorm, DL, VT);
+ SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
+ SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op);
+ SDValue IsDenorm = DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT);
+ Est = DAG.getNode(SelOpcode, DL, VT, IsDenorm, FPZero, Est);
+ AddToWorklist(Fabs.getNode());
+ AddToWorklist(IsDenorm.getNode());
+ AddToWorklist(Est.getNode());
+ } else {
+ // X == 0.0 ? 0.0 : Est
+ SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
+ SDValue IsZero = DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
+ Est = DAG.getNode(SelOpcode, DL, VT, IsZero, FPZero, Est);
+ AddToWorklist(IsZero.getNode());
+ AddToWorklist(Est.getNode());
+ }
+ }
}
return Est;
}
@@ -14436,41 +19019,12 @@ SDValue DAGCombiner::BuildRsqrtEstimate(SDValue Op, SDNodeFlags *Flags) {
return SDValue();
}
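+
+// Why the fixup above is needed (a sketch of the failure mode): sqrt is
+// computed as Op * rsqrt(Op), so for Op == 0.0 the expansion evaluates
+// 0.0 * inf == NaN instead of 0.0; the select forces those inputs (and,
+// under "ieee" denormal handling, anything below the smallest normal)
+// back to 0.0.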
-/// Return true if base is a frame index, which is known not to alias with
-/// anything but itself. Provides base object and offset as results.
-static bool FindBaseOffset(SDValue Ptr, SDValue &Base, int64_t &Offset,
- const GlobalValue *&GV, const void *&CV) {
- // Assume it is a primitive operation.
- Base = Ptr; Offset = 0; GV = nullptr; CV = nullptr;
-
- // If it's an adding a simple constant then integrate the offset.
- if (Base.getOpcode() == ISD::ADD) {
- if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Base.getOperand(1))) {
- Base = Base.getOperand(0);
- Offset += C->getZExtValue();
- }
- }
-
- // Return the underlying GlobalValue, and update the Offset. Return false
- // for GlobalAddressSDNode since the same GlobalAddress may be represented
- // by multiple nodes with different offsets.
- if (GlobalAddressSDNode *G = dyn_cast<GlobalAddressSDNode>(Base)) {
- GV = G->getGlobal();
- Offset += G->getOffset();
- return false;
- }
+SDValue DAGCombiner::buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags) {
+ return buildSqrtEstimateImpl(Op, Flags, true);
+}
- // Return the underlying Constant value, and update the Offset. Return false
- // for ConstantSDNodes since the same constant pool entry may be represented
- // by multiple nodes with different offsets.
- if (ConstantPoolSDNode *C = dyn_cast<ConstantPoolSDNode>(Base)) {
- CV = C->isMachineConstantPoolEntry() ? (const void *)C->getMachineCPVal()
- : (const void *)C->getConstVal();
- Offset += C->getOffset();
- return false;
- }
- // If it's any of the following then it can't alias with anything but itself.
- return isa<FrameIndexSDNode>(Base);
+SDValue DAGCombiner::buildSqrtEstimate(SDValue Op, SDNodeFlags Flags) {
+ return buildSqrtEstimateImpl(Op, Flags, false);
}
/// Return true if there is any possibility that the two addresses overlap.
@@ -14490,54 +19044,63 @@ bool DAGCombiner::isAlias(LSBaseSDNode *Op0, LSBaseSDNode *Op1) const {
if (Op1->isInvariant() && Op0->writeMem())
return false;
- // Gather base node and offset information.
- SDValue Base1, Base2;
- int64_t Offset1, Offset2;
- const GlobalValue *GV1, *GV2;
- const void *CV1, *CV2;
- bool isFrameIndex1 = FindBaseOffset(Op0->getBasePtr(),
- Base1, Offset1, GV1, CV1);
- bool isFrameIndex2 = FindBaseOffset(Op1->getBasePtr(),
- Base2, Offset2, GV2, CV2);
-
- // If they have a same base address then check to see if they overlap.
- if (Base1 == Base2 || (GV1 && (GV1 == GV2)) || (CV1 && (CV1 == CV2)))
- return !((Offset1 + (Op0->getMemoryVT().getSizeInBits() >> 3)) <= Offset2 ||
- (Offset2 + (Op1->getMemoryVT().getSizeInBits() >> 3)) <= Offset1);
-
- // It is possible for different frame indices to alias each other, mostly
- // when tail call optimization reuses return address slots for arguments.
- // To catch this case, look up the actual index of frame indices to compute
- // the real alias relationship.
- if (isFrameIndex1 && isFrameIndex2) {
- MachineFrameInfo *MFI = DAG.getMachineFunction().getFrameInfo();
- Offset1 += MFI->getObjectOffset(cast<FrameIndexSDNode>(Base1)->getIndex());
- Offset2 += MFI->getObjectOffset(cast<FrameIndexSDNode>(Base2)->getIndex());
- return !((Offset1 + (Op0->getMemoryVT().getSizeInBits() >> 3)) <= Offset2 ||
- (Offset2 + (Op1->getMemoryVT().getSizeInBits() >> 3)) <= Offset1);
- }
-
- // Otherwise, if we know what the bases are, and they aren't identical, then
- // we know they cannot alias.
- if ((isFrameIndex1 || CV1 || GV1) && (isFrameIndex2 || CV2 || GV2))
- return false;
+ unsigned NumBytes0 = Op0->getMemoryVT().getStoreSize();
+ unsigned NumBytes1 = Op1->getMemoryVT().getStoreSize();
+
+ // Check for BaseIndexOffset matching.
+ BaseIndexOffset BasePtr0 = BaseIndexOffset::match(Op0, DAG);
+ BaseIndexOffset BasePtr1 = BaseIndexOffset::match(Op1, DAG);
+ int64_t PtrDiff;
+ if (BasePtr0.getBase().getNode() && BasePtr1.getBase().getNode()) {
+ if (BasePtr0.equalBaseIndex(BasePtr1, DAG, PtrDiff))
+ return !((NumBytes0 <= PtrDiff) || (PtrDiff + NumBytes1 <= 0));
+
+ // If both BasePtr0 and BasePtr1 are FrameIndexes, we will not be
+ // able to calculate their relative offset if at least one arises
+ // from an alloca. However, these allocas cannot overlap and we
+ // can infer there is no alias.
+ if (auto *A = dyn_cast<FrameIndexSDNode>(BasePtr0.getBase()))
+ if (auto *B = dyn_cast<FrameIndexSDNode>(BasePtr1.getBase())) {
+ MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
+        // Without a constant offset between the two frame indexes, report
+        // no alias for distinct indices when at least one is a true alloca
+        // (allocas never overlap); otherwise be conservative.
+ if (A != B && (!MFI.isFixedObjectIndex(A->getIndex()) ||
+ !MFI.isFixedObjectIndex(B->getIndex())))
+ return false;
+ }
- // If we know required SrcValue1 and SrcValue2 have relatively large alignment
- // compared to the size and offset of the access, we may be able to prove they
- // do not alias. This check is conservative for now to catch cases created by
- // splitting vector types.
- if ((Op0->getOriginalAlignment() == Op1->getOriginalAlignment()) &&
- (Op0->getSrcValueOffset() != Op1->getSrcValueOffset()) &&
- (Op0->getMemoryVT().getSizeInBits() >> 3 ==
- Op1->getMemoryVT().getSizeInBits() >> 3) &&
- (Op0->getOriginalAlignment() > Op0->getMemoryVT().getSizeInBits()) >> 3) {
- int64_t OffAlign1 = Op0->getSrcValueOffset() % Op0->getOriginalAlignment();
- int64_t OffAlign2 = Op1->getSrcValueOffset() % Op1->getOriginalAlignment();
-
- // There is no overlap between these relatively aligned accesses of similar
- // size, return no alias.
- if ((OffAlign1 + (Op0->getMemoryVT().getSizeInBits() >> 3)) <= OffAlign2 ||
- (OffAlign2 + (Op1->getMemoryVT().getSizeInBits() >> 3)) <= OffAlign1)
+ bool IsFI0 = isa<FrameIndexSDNode>(BasePtr0.getBase());
+ bool IsFI1 = isa<FrameIndexSDNode>(BasePtr1.getBase());
+ bool IsGV0 = isa<GlobalAddressSDNode>(BasePtr0.getBase());
+ bool IsGV1 = isa<GlobalAddressSDNode>(BasePtr1.getBase());
+ bool IsCV0 = isa<ConstantPoolSDNode>(BasePtr0.getBase());
+ bool IsCV1 = isa<ConstantPoolSDNode>(BasePtr1.getBase());
+
+    // If both bases are of known kinds and either the kinds differ or the
+    // indices match, we can prove the accesses do not alias.
+ if ((BasePtr0.getIndex() == BasePtr1.getIndex() || (IsFI0 != IsFI1) ||
+ (IsGV0 != IsGV1) || (IsCV0 != IsCV1)) &&
+ (IsFI0 || IsGV0 || IsCV0) && (IsFI1 || IsGV1 || IsCV1))
+ return false;
+ }
+
+ // If we know required SrcValue1 and SrcValue2 have relatively large
+ // alignment compared to the size and offset of the access, we may be able
+ // to prove they do not alias. This check is conservative for now to catch
+ // cases created by splitting vector types.
+ int64_t SrcValOffset0 = Op0->getSrcValueOffset();
+ int64_t SrcValOffset1 = Op1->getSrcValueOffset();
+ unsigned OrigAlignment0 = Op0->getOriginalAlignment();
+ unsigned OrigAlignment1 = Op1->getOriginalAlignment();
+ if (OrigAlignment0 == OrigAlignment1 && SrcValOffset0 != SrcValOffset1 &&
+ NumBytes0 == NumBytes1 && OrigAlignment0 > NumBytes0) {
+ int64_t OffAlign0 = SrcValOffset0 % OrigAlignment0;
+ int64_t OffAlign1 = SrcValOffset1 % OrigAlignment1;
+
+ // There is no overlap between these relatively aligned accesses of
+ // similar size. Return no alias.
+ if ((OffAlign0 + NumBytes0) <= OffAlign1 ||
+ (OffAlign1 + NumBytes1) <= OffAlign0)
return false;
}
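+
+  // For example (sketch): two 4-byte accesses with original alignment 8 at
+  // source offsets 0 and 4 get OffAlign values 0 and 4; since 0 + 4 <= 4,
+  // they cannot overlap, so no alias is reported.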
@@ -14549,20 +19112,18 @@ bool DAGCombiner::isAlias(LSBaseSDNode *Op0, LSBaseSDNode *Op1) const {
CombinerAAOnlyFunc != DAG.getMachineFunction().getName())
UseAA = false;
#endif
- if (UseAA &&
+
+ if (UseAA && AA &&
Op0->getMemOperand()->getValue() && Op1->getMemOperand()->getValue()) {
// Use alias analysis information.
- int64_t MinOffset = std::min(Op0->getSrcValueOffset(),
- Op1->getSrcValueOffset());
- int64_t Overlap1 = (Op0->getMemoryVT().getSizeInBits() >> 3) +
- Op0->getSrcValueOffset() - MinOffset;
- int64_t Overlap2 = (Op1->getMemoryVT().getSizeInBits() >> 3) +
- Op1->getSrcValueOffset() - MinOffset;
+ int64_t MinOffset = std::min(SrcValOffset0, SrcValOffset1);
+ int64_t Overlap0 = NumBytes0 + SrcValOffset0 - MinOffset;
+ int64_t Overlap1 = NumBytes1 + SrcValOffset1 - MinOffset;
AliasResult AAResult =
- AA.alias(MemoryLocation(Op0->getMemOperand()->getValue(), Overlap1,
- UseTBAA ? Op0->getAAInfo() : AAMDNodes()),
- MemoryLocation(Op1->getMemOperand()->getValue(), Overlap2,
- UseTBAA ? Op1->getAAInfo() : AAMDNodes()));
+ AA->alias(MemoryLocation(Op0->getMemOperand()->getValue(), Overlap0,
+ UseTBAA ? Op0->getAAInfo() : AAMDNodes()),
+ MemoryLocation(Op1->getMemOperand()->getValue(), Overlap1,
+                                 UseTBAA ? Op1->getAAInfo() : AAMDNodes()));
if (AAResult == NoAlias)
return false;
}
@@ -14644,75 +19205,28 @@ void DAGCombiner::GatherAllAliases(SDNode *N, SDValue OriginalChain,
++Depth;
break;
+ case ISD::CopyFromReg:
+ // Forward past CopyFromReg.
+ Chains.push_back(Chain.getOperand(0));
+ ++Depth;
+ break;
+
default:
// For all other instructions we will just have to take what we can get.
Aliases.push_back(Chain);
break;
}
}
-
- // We need to be careful here to also search for aliases through the
- // value operand of a store, etc. Consider the following situation:
- // Token1 = ...
- // L1 = load Token1, %52
- // S1 = store Token1, L1, %51
- // L2 = load Token1, %52+8
- // S2 = store Token1, L2, %51+8
- // Token2 = Token(S1, S2)
- // L3 = load Token2, %53
- // S3 = store Token2, L3, %52
- // L4 = load Token2, %53+8
- // S4 = store Token2, L4, %52+8
- // If we search for aliases of S3 (which loads address %52), and we look
- // only through the chain, then we'll miss the trivial dependence on L1
- // (which also loads from %52). We then might change all loads and
- // stores to use Token1 as their chain operand, which could result in
- // copying %53 into %52 before copying %52 into %51 (which should
- // happen first).
- //
- // The problem is, however, that searching for such data dependencies
- // can become expensive, and the cost is not directly related to the
- // chain depth. Instead, we'll rule out such configurations here by
- // insisting that we've visited all chain users (except for users
- // of the original chain, which is not necessary). When doing this,
- // we need to look through nodes we don't care about (otherwise, things
- // like register copies will interfere with trivial cases).
-
- SmallVector<const SDNode *, 16> Worklist;
- for (const SDNode *N : Visited)
- if (N != OriginalChain.getNode())
- Worklist.push_back(N);
-
- while (!Worklist.empty()) {
- const SDNode *M = Worklist.pop_back_val();
-
- // We have already visited M, and want to make sure we've visited any uses
- // of M that we care about. For uses that we've not visisted, and don't
- // care about, queue them to the worklist.
-
- for (SDNode::use_iterator UI = M->use_begin(),
- UIE = M->use_end(); UI != UIE; ++UI)
- if (UI.getUse().getValueType() == MVT::Other &&
- Visited.insert(*UI).second) {
- if (isa<MemSDNode>(*UI)) {
- // We've not visited this use, and we care about it (it could have an
- // ordering dependency with the original node).
- Aliases.clear();
- Aliases.push_back(OriginalChain);
- return;
- }
-
- // We've not visited this use, but we don't care about it. Mark it as
- // visited and enqueue it to the worklist.
- Worklist.push_back(*UI);
- }
- }
}
/// Walk up chain skipping non-aliasing memory nodes, looking for a better chain
/// (aliasing node).
SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) {
- SmallVector<SDValue, 8> Aliases; // Ops for replacing token factor.
+ if (OptLevel == CodeGenOpt::None)
+ return OldChain;
+
+ // Ops for replacing token factor.
+ SmallVector<SDValue, 8> Aliases;
// Accumulate all the aliases to this node.
GatherAllAliases(N, OldChain, Aliases);
@@ -14729,85 +19243,160 @@ SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) {
return DAG.getNode(ISD::TokenFactor, SDLoc(N), MVT::Other, Aliases);
}
-bool DAGCombiner::findBetterNeighborChains(StoreSDNode* St) {
+// TODO: Replace with std::monostate when we move to C++17.
+struct UnitT { } Unit;
+bool operator==(const UnitT &, const UnitT &) { return true; }
+bool operator!=(const UnitT &, const UnitT &) { return false; }
+
+// This function tries to collect a number of potentially interesting
+// nodes whose chains can be improved all at once. This might seem
+// redundant, as this function gets called when visiting every store
+// node, so why not let the work be done on each store as it's visited?
+//
+// I believe this is mainly important because MergeConsecutiveStores
+// is unable to deal with merging stores of different sizes, so unless
+// we improve the chains of all the potential candidates up-front
+// before running MergeConsecutiveStores, it might only see some of
+// the nodes that will eventually be candidates, and then not be able
+// to go from a partially-merged state to the desired final
+// fully-merged state.
+
+bool DAGCombiner::parallelizeChainedStores(StoreSDNode *St) {
+ SmallVector<StoreSDNode *, 8> ChainedStores;
+ StoreSDNode *STChain = St;
+  // Intervals records which offsets from BaseIndex have been covered. In
+  // the common case, every store writes to the immediately preceding address,
+  // and is thus merged with the previous interval at insertion time.
+
+ using IMap =
+ llvm::IntervalMap<int64_t, UnitT, 8, IntervalMapHalfOpenInfo<int64_t>>;
+ IMap::Allocator A;
+ IMap Intervals(A);
+
// This holds the base pointer, index, and the offset in bytes from the base
// pointer.
- BaseIndexOffset BasePtr = BaseIndexOffset::match(St->getBasePtr());
+ const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
// We must have a base and an offset.
- if (!BasePtr.Base.getNode())
+ if (!BasePtr.getBase().getNode())
return false;
// Do not handle stores to undef base pointers.
- if (BasePtr.Base.getOpcode() == ISD::UNDEF)
+ if (BasePtr.getBase().isUndef())
return false;
- SmallVector<StoreSDNode *, 8> ChainedStores;
- ChainedStores.push_back(St);
+ // Add ST's interval.
+ Intervals.insert(0, (St->getMemoryVT().getSizeInBits() + 7) / 8, Unit);
- // Walk up the chain and look for nodes with offsets from the same
- // base pointer. Stop when reaching an instruction with a different kind
- // or instruction which has a different base pointer.
- StoreSDNode *Index = St;
- while (Index) {
+ while (StoreSDNode *Chain = dyn_cast<StoreSDNode>(STChain->getChain())) {
// If the chain has more than one use, then we can't reorder the mem ops.
- if (Index != St && !SDValue(Index, 0)->hasOneUse())
+ if (!SDValue(Chain, 0)->hasOneUse())
break;
-
- if (Index->isVolatile() || Index->isIndexed())
+ if (Chain->isVolatile() || Chain->isIndexed())
break;
// Find the base pointer and offset for this memory node.
- BaseIndexOffset Ptr = BaseIndexOffset::match(Index->getBasePtr());
-
+ const BaseIndexOffset Ptr = BaseIndexOffset::match(Chain, DAG);
// Check that the base pointer is the same as the original one.
- if (!Ptr.equalBaseIndex(BasePtr))
+ int64_t Offset;
+ if (!BasePtr.equalBaseIndex(Ptr, DAG, Offset))
+ break;
+ int64_t Length = (Chain->getMemoryVT().getSizeInBits() + 7) / 8;
+ // Make sure we don't overlap with other intervals by checking the ones to
+ // the left or right before inserting.
+ auto I = Intervals.find(Offset);
+ // If there's a next interval, we should end before it.
+ if (I != Intervals.end() && I.start() < (Offset + Length))
+ break;
+ // If there's a previous interval, we should start after it.
+ if (I != Intervals.begin() && (--I).stop() <= Offset)
break;
+ Intervals.insert(Offset, Offset + Length, Unit);
- // Find the next memory operand in the chain. If the next operand in the
- // chain is a store then move up and continue the scan with the next
- // memory operand. If the next operand is a load save it and use alias
- // information to check if it interferes with anything.
- SDNode *NextInChain = Index->getChain().getNode();
- while (true) {
- if (StoreSDNode *STn = dyn_cast<StoreSDNode>(NextInChain)) {
- // We found a store node. Use it for the next iteration.
- ChainedStores.push_back(STn);
- Index = STn;
- break;
- } else if (LoadSDNode *Ldn = dyn_cast<LoadSDNode>(NextInChain)) {
- NextInChain = Ldn->getChain().getNode();
- continue;
- } else {
- Index = nullptr;
- break;
- }
- }
+ ChainedStores.push_back(Chain);
+ STChain = Chain;
}
- bool MadeChange = false;
- SmallVector<std::pair<StoreSDNode *, SDValue>, 8> BetterChains;
+ // If we didn't find a chained store, exit.
+ if (ChainedStores.size() == 0)
+ return false;
- for (StoreSDNode *ChainedStore : ChainedStores) {
- SDValue Chain = ChainedStore->getChain();
- SDValue BetterChain = FindBetterChain(ChainedStore, Chain);
+  // Improve all chained stores (St and the ChainedStores members) starting
+  // from where the store chain ended, and return a single TokenFactor.
+ SDValue NewChain = STChain->getChain();
+ SmallVector<SDValue, 8> TFOps;
+ for (unsigned I = ChainedStores.size(); I;) {
+ StoreSDNode *S = ChainedStores[--I];
+ SDValue BetterChain = FindBetterChain(S, NewChain);
+ S = cast<StoreSDNode>(DAG.UpdateNodeOperands(
+ S, BetterChain, S->getOperand(1), S->getOperand(2), S->getOperand(3)));
+ TFOps.push_back(SDValue(S, 0));
+ ChainedStores[I] = S;
+ }
+
+ // Improve St's chain. Use a new node to avoid creating a loop from CombineTo.
+ SDValue BetterChain = FindBetterChain(St, NewChain);
+ SDValue NewST;
+ if (St->isTruncatingStore())
+ NewST = DAG.getTruncStore(BetterChain, SDLoc(St), St->getValue(),
+ St->getBasePtr(), St->getMemoryVT(),
+ St->getMemOperand());
+ else
+ NewST = DAG.getStore(BetterChain, SDLoc(St), St->getValue(),
+ St->getBasePtr(), St->getMemOperand());
- if (Chain != BetterChain) {
- MadeChange = true;
- BetterChains.push_back(std::make_pair(ChainedStore, BetterChain));
- }
- }
+ TFOps.push_back(NewST);
+
+ // If we improved every element of TFOps, then we've lost the dependence on
+ // NewChain to successors of St and we need to add it back to TFOps. Do so at
+ // the beginning to keep relative order consistent with FindBetterChains.
+ auto hasImprovedChain = [&](SDValue ST) -> bool {
+ return ST->getOperand(0) != NewChain;
+ };
+ bool AddNewChain = llvm::all_of(TFOps, hasImprovedChain);
+ if (AddNewChain)
+ TFOps.insert(TFOps.begin(), NewChain);
+
+ SDValue TF = DAG.getNode(ISD::TokenFactor, SDLoc(STChain), MVT::Other, TFOps);
+ CombineTo(St, TF);
+
+ AddToWorklist(STChain);
+  // Add TF operands to the worklist in reverse order.
+ for (auto I = TF->getNumOperands(); I;)
+ AddToWorklist(TF->getOperand(--I).getNode());
+ AddToWorklist(TF.getNode());
+ return true;
+}
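+
+// Walking the chain upward from St (an illustrative sketch): with St writing
+// [0, 4) and the chained i32 stores writing [-4, 0), [-8, -4), ... relative
+// to the same base, each new interval passes the overlap checks above, so the
+// whole run is collected and rewritten to hang off a single TokenFactor
+// instead of a serial chain.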
+
+bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
+ if (OptLevel == CodeGenOpt::None)
+ return false;
+
+ const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
+
+ // We must have a base and an offset.
+ if (!BasePtr.getBase().getNode())
+ return false;
+
+ // Do not handle stores to undef base pointers.
+ if (BasePtr.getBase().isUndef())
+ return false;
- // Do all replacements after finding the replacements to make to avoid making
- // the chains more complicated by introducing new TokenFactors.
- for (auto Replacement : BetterChains)
- replaceStoreChain(Replacement.first, Replacement.second);
+ // Directly improve a chain of disjoint stores starting at St.
+ if (parallelizeChainedStores(St))
+ return true;
- return MadeChange;
+  // Improve St's chain.
+ SDValue BetterChain = FindBetterChain(St, St->getChain());
+ if (St->getChain() != BetterChain) {
+ replaceStoreChain(St, BetterChain);
+ return true;
+ }
+ return false;
}
/// This is the entry point for the file.
-void SelectionDAG::Combine(CombineLevel Level, AliasAnalysis &AA,
+void SelectionDAG::Combine(CombineLevel Level, AliasAnalysis *AA,
CodeGenOpt::Level OptLevel) {
/// This is the main entry point to this class.
DAGCombiner(*this, AA, OptLevel).Run(Level);