/*
 * Decompiled with CFR 0.152.
 */
package org.eclipse.rdf4j.query.algebra.evaluation.optimizer;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.eclipse.rdf4j.query.BindingSet;
import org.eclipse.rdf4j.query.Dataset;
import org.eclipse.rdf4j.query.algebra.AbstractQueryModelNode;
import org.eclipse.rdf4j.query.algebra.BindingSetAssignment;
import org.eclipse.rdf4j.query.algebra.Extension;
import org.eclipse.rdf4j.query.algebra.Join;
import org.eclipse.rdf4j.query.algebra.LeftJoin;
import org.eclipse.rdf4j.query.algebra.QueryModelNode;
import org.eclipse.rdf4j.query.algebra.StatementPattern;
import org.eclipse.rdf4j.query.algebra.TupleExpr;
import org.eclipse.rdf4j.query.algebra.Var;
import org.eclipse.rdf4j.query.algebra.ZeroLengthPath;
import org.eclipse.rdf4j.query.algebra.evaluation.QueryOptimizer;
import org.eclipse.rdf4j.query.algebra.evaluation.impl.EvaluationStatistics;
import org.eclipse.rdf4j.query.algebra.helpers.AbstractSimpleQueryModelVisitor;
import org.eclipse.rdf4j.query.algebra.helpers.StatementPatternVisitor;
import org.eclipse.rdf4j.query.algebra.helpers.TupleExprs;

public class QueryJoinOptimizer
implements QueryOptimizer {
    protected final EvaluationStatistics statistics;
    private final boolean trackResultSize;

    /**
     * Creates an optimizer that uses the supplied statistics and leaves result-size tracking switched off.
     *
     * @param statistics source of the cardinality estimates consulted while ordering join arguments
     */
    public QueryJoinOptimizer(EvaluationStatistics statistics) {
        // Delegate to the two-argument constructor with trackResultSize = false.
        this(statistics, false);
    }

    /**
     * Creates an optimizer backed by the supplied statistics.
     *
     * @param statistics      source of the cardinality estimates consulted while ordering join arguments
     * @param trackResultSize flag handed unchanged to every visitor created by
     *                        {@link #optimize(TupleExpr, Dataset, BindingSet)}
     */
    public QueryJoinOptimizer(EvaluationStatistics statistics, boolean trackResultSize) {
        this.trackResultSize = trackResultSize;
        this.statistics = statistics;
    }

    /**
     * Walks the given expression with a freshly created join visitor, which rewrites the
     * join nodes it meets in place. The {@code dataset} and {@code bindings} arguments are
     * not consulted.
     *
     * @param tupleExpr the expression tree to optimize
     * @param dataset   unused
     * @param bindings  unused
     */
    @Override
    public void optimize(TupleExpr tupleExpr, Dataset dataset, BindingSet bindings) {
        JoinVisitor visitor = new JoinVisitor(this.statistics, this.trackResultSize);
        tupleExpr.visit(visitor);
    }

    /**
     * Computes the size of the union of two name sets without building the union.
     *
     * @param currentListNames      names already collected
     * @param candidateBindingNames names contributed by the candidate
     * @return the number of distinct names occurring in either set
     */
    private static int getUnionSize(Set<String> currentListNames, Set<String> candidateBindingNames) {
        // Start from the candidate's names and add every current name the candidate lacks.
        int unionSize = candidateBindingNames.size();
        for (String name : currentListNames) {
            if (!candidateBindingNames.contains(name)) {
                unionSize++;
            }
        }
        return unionSize;
    }

    /**
     * Counts how many of the given names also occur in the current name set, i.e. the size
     * of the intersection of the two sets.
     *
     * @param currentListNames names already collected
     * @param names            names to look up in {@code currentListNames}
     * @return the number of elements of {@code names} contained in {@code currentListNames}
     */
    private static int getJoinSize(Set<String> currentListNames, Set<String> names) {
        return (int) names.stream()
                .filter(name -> currentListNames.contains(name))
                .count();
    }

    /**
     * Tells whether the expression is an {@link AbstractQueryModelNode} that reports its
     * cardinality as already set.
     *
     * @param tupleExpr the expression to inspect
     * @return {@code true} only for an {@code AbstractQueryModelNode} whose
     *         {@code isCardinalitySet()} returns {@code true}
     */
    private static boolean hasCachedCardinality(TupleExpr tupleExpr) {
        if (!(tupleExpr instanceof AbstractQueryModelNode)) {
            return false;
        }
        AbstractQueryModelNode node = (AbstractQueryModelNode) tupleExpr;
        return node.isCardinalitySet();
    }

    private static class JoinVisitor
    extends AbstractSimpleQueryModelVisitor<RuntimeException> {
        private final EvaluationStatistics statistics;
        Set<String> boundVars = new HashSet<String>();

        /**
         * Creates a visitor that obtains its cardinality estimates from the supplied statistics.
         *
         * @param statistics      source of cardinality estimates for the visited expressions
         * @param trackResultSize passed unchanged to the superclass constructor
         */
        protected JoinVisitor(EvaluationStatistics statistics, boolean trackResultSize) {
            super(trackResultSize);
            this.statistics = statistics;
        }

        /**
         * Visits both sides of a left join. The right side is visited with a scope in which
         * the binding names of the left side count as bound; the previous scope is restored
         * afterwards, even when the visit throws.
         */
        @Override
        public void meet(LeftJoin leftJoin) {
            leftJoin.getLeftArg().visit(this);

            final Set<String> outerScope = this.boundVars;
            try {
                Set<String> rightScope = new HashSet<String>(outerScope);
                this.boundVars = rightScope;
                // Read the left argument again: visiting it may have replaced the node.
                rightScope.addAll(leftJoin.getLeftArg().getBindingNames());
                leftJoin.getRightArg().visit(this);
            } finally {
                this.boundVars = outerScope;
            }
        }

        /**
         * Raises the result size estimate of a statement pattern to the cardinality reported
         * by the statistics, unless the estimate already stored on the node is larger.
         */
        @Override
        public void meet(StatementPattern node) throws RuntimeException {
            double estimated = this.statistics.getCardinality(node);
            double current = node.getResultSizeEstimate();
            node.setResultSizeEstimate(Math.max(estimated, current));
        }

        /**
         * Visits the priority join with a scope that is a copy of {@code origBoundVars}, then
         * restores whatever scope was active before the call, even when the visit throws.
         *
         * @param origBoundVars names to treat as bound while visiting {@code join}
         * @param join          the expression to visit
         */
        private void optimizePriorityJoin(Set<String> origBoundVars, TupleExpr join) {
            final Set<String> previousScope = this.boundVars;
            try {
                this.boundVars = new HashSet<String>(origBoundVars);
                join.visit(this);
            } finally {
                this.boundVars = previousScope;
            }
        }

        /**
         * Reorders the arguments of a (possibly nested) join and replaces the node with the
         * re-built join tree.
         * <ol>
         * <li>Nested joins are flattened into one argument list.</li>
         * <li>Arguments containing an extension, then arguments containing a subquery, are
         * split off as "priority" arguments and chained into a left-deep join.</li>
         * <li>The remaining arguments are picked one at a time by
         * {@link #selectNextTupleExpr}; each picked argument is visited and its binding names
         * are added to the bound variables before the next pick.</li>
         * <li>The ordered arguments are assembled into a right-deep join, prefixed by the
         * priority join when there is one, and the original node is replaced.</li>
         * </ol>
         * The set of bound variables active on entry is restored on exit, even when an
         * exception is thrown.
         */
        @Override
        public void meet(Join node) {
            Set<String> origBoundVars = this.boundVars;
            try {
                List<TupleExpr> priorityArgs;
                // Work on a copy so names bound while ordering this join do not leak to the caller.
                this.boundVars = new HashSet<String>(this.boundVars);
                List<TupleExpr> joinArgs = this.getJoinArgs(node, new ArrayList<TupleExpr>());
                List<TupleExpr> orderedExtensions = this.getExtensionTupleExprs(joinArgs);
                joinArgs.removeAll(orderedExtensions);
                List<TupleExpr> orderedSubselects = this.reorderSubselects(this.getSubSelects(joinArgs));
                joinArgs.removeAll(orderedSubselects);
                // Only allocate a combined list when both kinds of priority argument exist.
                if (orderedExtensions.isEmpty()) {
                    priorityArgs = orderedSubselects;
                } else if (orderedSubselects.isEmpty()) {
                    priorityArgs = orderedExtensions;
                } else {
                    priorityArgs = new ArrayList<TupleExpr>(orderedExtensions.size() + orderedSubselects.size());
                    priorityArgs.addAll(orderedExtensions);
                    priorityArgs.addAll(orderedSubselects);
                }
                List<TupleExpr> orderedJoinArgs = new ArrayList<TupleExpr>(joinArgs.size());
                if (!joinArgs.isEmpty()) {
                    // The cardinality map is created lazily: it is only needed for arguments
                    // that do not report a cached cardinality.
                    Map<TupleExpr, Double> cardinalityMap = Collections.emptyMap();
                    Map<TupleExpr, List<Var>> varsMap = new HashMap<TupleExpr, List<Var>>();
                    for (TupleExpr tupleExpr : joinArgs) {
                        if (tupleExpr instanceof Join) continue;
                        double cardinality = this.statistics.getCardinality(tupleExpr);
                        tupleExpr.setResultSizeEstimate(Math.max(cardinality, tupleExpr.getResultSizeEstimate()));
                        if (!QueryJoinOptimizer.hasCachedCardinality(tupleExpr)) {
                            if (cardinalityMap.isEmpty()) {
                                cardinalityMap = new HashMap<TupleExpr, Double>();
                            }
                            cardinalityMap.put(tupleExpr, cardinality);
                        }
                        if (tupleExpr instanceof ZeroLengthPath) {
                            varsMap.put(tupleExpr, ((ZeroLengthPath)tupleExpr).getVarList());
                            continue;
                        }
                        varsMap.put(tupleExpr, this.getStatementPatternVars(tupleExpr));
                    }
                    // Count how often each variable occurs across all remaining arguments.
                    Map<Var, Integer> varFreqMap = new HashMap<Var, Integer>((varsMap.size() + 1) * 2);
                    for (List<Var> varList : varsMap.values()) {
                        this.fillVarFreqMap(varList, varFreqMap);
                    }
                    // Repeatedly pick the next argument; costs depend on the variables bound so far.
                    while (!joinArgs.isEmpty()) {
                        TupleExpr tupleExpr = this.selectNextTupleExpr(joinArgs, cardinalityMap, varsMap, varFreqMap);
                        joinArgs.remove(tupleExpr);
                        orderedJoinArgs.add(tupleExpr);
                        tupleExpr.visit(this);
                        this.boundVars.addAll(tupleExpr.getBindingNames());
                    }
                }
                // Chain the priority arguments into a left-deep join.
                TupleExpr priorityJoins = null;
                if (!priorityArgs.isEmpty()) {
                    priorityJoins = priorityArgs.get(0);
                    for (int i = 1; i < priorityArgs.size(); ++i) {
                        priorityJoins = new Join(priorityJoins, priorityArgs.get(i));
                    }
                }
                if (!orderedJoinArgs.isEmpty()) {
                    // Build a right-deep join from the ordered arguments, last argument innermost.
                    int i = orderedJoinArgs.size() - 1;
                    TupleExpr replacement = orderedJoinArgs.get(i);
                    --i;
                    while (i >= 0) {
                        replacement = new Join(orderedJoinArgs.get(i), replacement);
                        --i;
                    }
                    if (priorityJoins != null) {
                        replacement = new Join(priorityJoins, replacement);
                    }
                    node.replaceWith(replacement);
                    // The priority join is visited only after the replacement above, using the
                    // bound variables that were active when this method was entered.
                    if (priorityJoins != null) {
                        this.optimizePriorityJoin(origBoundVars, priorityJoins);
                    }
                } else {
                    // Only priority arguments took part in this join.
                    node.replaceWith(priorityJoins);
                }
            }
            finally {
                this.boundVars = origBoundVars;
            }
        }

        /**
         * Flattens a tree of nested joins into the supplied list. Every expression that is
         * not itself a {@link Join} is appended, left operand before right operand.
         *
         * @param tupleExpr root of the (sub)tree to flatten
         * @param joinArgs  list receiving the non-join arguments
         * @return the {@code joinArgs} instance that was passed in
         */
        protected <L extends List<TupleExpr>> L getJoinArgs(TupleExpr tupleExpr, L joinArgs) {
            if (!(tupleExpr instanceof Join)) {
                joinArgs.add(tupleExpr);
                return joinArgs;
            }
            Join nested = (Join) tupleExpr;
            this.getJoinArgs(nested.getLeftArg(), joinArgs);
            this.getJoinArgs(nested.getRightArg(), joinArgs);
            return joinArgs;
        }

        /**
         * Returns the variables to associate with an expression: the variable list of a
         * statement pattern, an empty list for a binding set assignment, and otherwise the
         * variables of all statement patterns collected from the expression.
         *
         * @param tupleExpr the expression to inspect
         * @return the variables found for the expression
         */
        protected List<Var> getStatementPatternVars(TupleExpr tupleExpr) {
            if (tupleExpr instanceof StatementPattern) {
                StatementPattern pattern = (StatementPattern) tupleExpr;
                return pattern.getVarList();
            }
            boolean isAssignment = tupleExpr instanceof BindingSetAssignment;
            return isAssignment
                    ? List.<Var>of()
                    : new StatementPatternVarCollector(tupleExpr).getVars();
        }

        /**
         * Adds one to the counter of every variable occurrence in the list; a variable that
         * is not yet in the map starts at one.
         *
         * @param varList    variables to count, duplicates included
         * @param varFreqMap map from variable to occurrence count, updated in place
         */
        protected <M extends Map<Var, Integer>> void fillVarFreqMap(List<Var> varList, M varFreqMap) {
            for (Var occurrence : varList) {
                varFreqMap.merge(occurrence, 1, Integer::sum);
            }
        }

        /**
         * Selects the expressions that are themselves {@link Extension} nodes.
         *
         * @param expressions the expressions to filter
         * @return a new mutable list holding the extensions in their original order
         */
        protected List<Extension> getExtensions(List<TupleExpr> expressions) {
            ArrayList<Extension> found = new ArrayList<Extension>();
            expressions.forEach(candidate -> {
                if (candidate instanceof Extension) {
                    found.add((Extension) candidate);
                }
            });
            return found;
        }

        /**
         * Selects the expressions for which {@code TupleExprs.containsExtension} holds,
         * keeping their order. The result is an immutable list when zero or one expression
         * matches and an {@link ArrayList} when two or more match.
         *
         * @param expressions the expressions to filter
         * @return the matching expressions
         */
        private List<TupleExpr> getExtensionTupleExprs(List<TupleExpr> expressions) {
            List<TupleExpr> matches = List.of();
            for (TupleExpr candidate : expressions) {
                if (TupleExprs.containsExtension(candidate)) {
                    if (matches.isEmpty()) {
                        // First hit: an immutable singleton avoids allocating a growable list.
                        matches = List.of(candidate);
                    } else {
                        if (matches.size() == 1) {
                            // Second hit: switch to a growable list before appending.
                            matches = new ArrayList<TupleExpr>(matches);
                        }
                        matches.add(candidate);
                    }
                }
            }
            return matches;
        }

        /**
         * Selects the expressions for which {@code TupleExprs.containsSubquery} holds,
         * keeping their order. The result is an immutable list when zero or one expression
         * matches and an {@link ArrayList} when two or more match.
         *
         * @param expressions the expressions to filter
         * @return the matching expressions
         */
        protected List<TupleExpr> getSubSelects(List<TupleExpr> expressions) {
            List<TupleExpr> selected = List.of();
            if (expressions.isEmpty()) {
                return selected;
            }
            for (TupleExpr candidate : expressions) {
                boolean hasSubquery = TupleExprs.containsSubquery(candidate);
                if (hasSubquery && selected.isEmpty()) {
                    // First hit: keep it in an immutable singleton.
                    selected = List.of(candidate);
                } else if (hasSubquery) {
                    if (selected.size() == 1) {
                        // Second hit: switch to a growable list before appending.
                        selected = new ArrayList<TupleExpr>(selected);
                    }
                    selected.add(candidate);
                }
            }
            return selected;
        }

        /**
         * Orders sub-select expressions so that expressions sharing many binding names end
         * up next to each other.
         * <p>
         * The first two entries are the pair with the largest number of shared binding
         * names; among pairs with that number, the first pair with the largest combined
         * number of binding names is chosen. The remaining expressions are appended one at a
         * time via {@link #getNextSubselect}.
         * <p>
         * The sets returned by {@code getBindingNames()} are only read, never modified.
         *
         * @param subSelects the expressions to order
         * @return {@code subSelects} itself when it holds exactly one element, otherwise a
         *         new list (empty when {@code subSelects} is empty)
         */
        protected List<TupleExpr> reorderSubselects(List<TupleExpr> subSelects) {
            if (subSelects.size() == 1) {
                return subSelects;
            }
            List<TupleExpr> result = new ArrayList<TupleExpr>();
            if (subSelects.isEmpty()) {
                return result;
            }

            // Group every unordered pair by the number of binding names the two share.
            Map<Integer, List<TupleExpr[]>> joinSizes = new HashMap<Integer, List<TupleExpr[]>>();
            int maxJoinSize = 0;
            for (int i = 0; i < subSelects.size(); ++i) {
                TupleExpr firstArg = subSelects.get(i);
                Set<String> firstNames = firstArg.getBindingNames();
                for (int j = i + 1; j < subSelects.size(); ++j) {
                    TupleExpr secondArg = subSelects.get(j);
                    int joinSize = QueryJoinOptimizer.getJoinSize(firstNames, secondArg.getBindingNames());
                    if (joinSize > maxJoinSize) {
                        maxJoinSize = joinSize;
                    }
                    joinSizes.computeIfAbsent(joinSize, key -> new ArrayList<TupleExpr[]>())
                            .add(new TupleExpr[]{firstArg, secondArg});
                }
            }

            // Among the pairs with the most shared names, prefer the largest union of names.
            TupleExpr[] maxUnionTupleTuple = null;
            int currentUnionSize = -1;
            for (TupleExpr[] tupleTuple : joinSizes.get(maxJoinSize)) {
                // Count the union instead of calling addAll on a set owned by the node.
                int unionSize = QueryJoinOptimizer.getUnionSize(
                        tupleTuple[0].getBindingNames(), tupleTuple[1].getBindingNames());
                if (unionSize <= currentUnionSize) continue;
                maxUnionTupleTuple = tupleTuple;
                currentUnionSize = unionSize;
            }
            assert (maxUnionTupleTuple != null);
            result.add(maxUnionTupleTuple[0]);
            result.add(maxUnionTupleTuple[1]);

            while (result.size() < subSelects.size()) {
                result.add(this.getNextSubselect(result, subSelects));
            }
            return result;
        }

        /**
         * Picks, from {@code joinArgs}, the next expression to append to {@code currentList}.
         * The candidate sharing the most binding names with the expressions already chosen
         * wins; ties are broken in favour of the candidate that yields the larger union of
         * binding names, and remaining ties in favour of the earlier candidate.
         * <p>
         * Candidates are excluded by object identity rather than {@code equals()}, so a
         * candidate that is a distinct instance is never skipped because it compares equal
         * to an expression that was already chosen.
         *
         * @param currentList expressions chosen so far
         * @param joinArgs    all candidate expressions, including those already chosen
         * @return the selected expression, or {@code null} when every element of
         *         {@code joinArgs} is already in {@code currentList}
         */
        private TupleExpr getNextSubselect(List<TupleExpr> currentList, List<TupleExpr> joinArgs) {
            Set<String> currentListNames = new HashSet<String>();
            for (TupleExpr expr : currentList) {
                currentListNames.addAll(expr.getBindingNames());
            }
            TupleExpr selected = null;
            int currentUnionSize = -1;
            int currentJoinSize = -1;
            for (TupleExpr candidate : joinArgs) {
                if (containsInstance(currentList, candidate)) continue;
                // Fetch the binding names once and reuse them for both measures.
                Set<String> candidateBindingNames = candidate.getBindingNames();
                int joinSize = QueryJoinOptimizer.getJoinSize(currentListNames, candidateBindingNames);
                int unionSize = QueryJoinOptimizer.getUnionSize(currentListNames, candidateBindingNames);
                if (joinSize > currentJoinSize) {
                    selected = candidate;
                    currentJoinSize = joinSize;
                    currentUnionSize = unionSize;
                    continue;
                }
                if (joinSize != currentJoinSize || unionSize <= currentUnionSize) continue;
                selected = candidate;
                currentUnionSize = unionSize;
            }
            return selected;
        }

        /** Tells whether the list holds the very same object as {@code target}. */
        private static boolean containsInstance(List<TupleExpr> list, TupleExpr target) {
            for (TupleExpr element : list) {
                if (element == target) {
                    return true;
                }
            }
            return false;
        }

        /**
         * Selects the expression with the lowest cost according to
         * {@link #getTupleExprCost} and stores that cost on the selected node.
         * <p>
         * When only one expression is left it is returned directly; its cost estimate is
         * computed and stored only if the stored estimate is negative. Otherwise the first
         * expression with the strictly lowest cost wins, and the scan stops early as soon
         * as a selected expression has a cost of exactly zero.
         *
         * @param expressions    candidates to choose from; expected to be non-empty
         * @param cardinalityMap cardinalities of expressions without a cached cardinality
         * @param varsMap        variables of each candidate
         * @param varFreqMap     occurrence count of each variable across the candidates
         * @return the selected expression
         */
        protected TupleExpr selectNextTupleExpr(List<TupleExpr> expressions, Map<TupleExpr, Double> cardinalityMap, Map<TupleExpr, List<Var>> varsMap, Map<Var, Integer> varFreqMap) {
            if (expressions.size() == 1) {
                TupleExpr tupleExpr = expressions.get(0);
                if (tupleExpr.getCostEstimate() < 0.0) {
                    tupleExpr.setCostEstimate(this.getTupleExprCost(tupleExpr, cardinalityMap, varsMap, varFreqMap));
                }
                return tupleExpr;
            }
            // Declared as TupleExpr so the value can be returned without a cast.
            TupleExpr result = null;
            double lowestCost = Double.POSITIVE_INFINITY;
            for (TupleExpr tupleExpr : expressions) {
                double cost = this.getTupleExprCost(tupleExpr, cardinalityMap, varsMap, varFreqMap);
                // The "result == null" escape guarantees a selection even when no cost
                // compares lower than the initial infinity (e.g. NaN or infinite costs).
                if (!(cost < lowestCost) && result != null) continue;
                lowestCost = cost;
                result = tupleExpr;
                if (cost != 0.0) continue;
                break;
            }
            assert (result != null);
            result.setCostEstimate(lowestCost);
            return result;
        }

        /**
         * Estimates the cost of evaluating an expression next.
         * <ul>
         * <li>A {@link BindingSetAssignment} that assuredly binds a variable present in
         * {@code varFreqMap} costs zero.</li>
         * <li>Otherwise the base cost is the cached cardinality of the node, or the entry in
         * {@code cardinalityMap} when no cardinality is cached.</li>
         * <li>With at least one non-constant variable, the base cost is raised to the power
         * {@code unbound / nonConstant}.</li>
         * <li>With no unbound variables the cost is divided by the number of non-constant
         * variables (when positive); otherwise it is divided by one plus the foreign
         * frequency of the unbound variables (when positive).</li>
         * </ul>
         *
         * @param tupleExpr      the expression to rate
         * @param cardinalityMap cardinalities of expressions without a cached cardinality
         * @param varsMap        variables of each expression
         * @param varFreqMap     occurrence count of each variable across the expressions
         * @return the estimated cost
         */
        protected double getTupleExprCost(TupleExpr tupleExpr, Map<TupleExpr, Double> cardinalityMap, Map<TupleExpr, List<Var>> varsMap, Map<Var, Integer> varFreqMap) {
            if (tupleExpr instanceof BindingSetAssignment) {
                Set<Var> knownVars = varFreqMap.keySet();
                for (String bindingName : tupleExpr.getAssuredBindingNames()) {
                    if (knownVars.contains(new Var(bindingName))) {
                        return 0.0;
                    }
                }
            }

            double cost;
            if (QueryJoinOptimizer.hasCachedCardinality(tupleExpr)) {
                cost = ((AbstractQueryModelNode) tupleExpr).getCardinality();
            } else {
                cost = cardinalityMap.get(tupleExpr).doubleValue();
            }

            List<Var> vars = varsMap.get(tupleExpr);
            List<Var> unbound = this.getUnboundVars(vars);
            int nonConstant = vars.size() - this.countConstantVars(vars);
            boolean hasNonConstant = nonConstant > 0;

            if (hasNonConstant) {
                // The larger the share of still-unbound variables, the closer to the full cardinality.
                cost = Math.pow(cost, (double) unbound.size() / (double) nonConstant);
            }

            if (!unbound.isEmpty()) {
                int foreignFreq = this.getForeignVarFreq(unbound, varFreqMap);
                if (foreignFreq > 0) {
                    cost /= (double) (1 + foreignFreq);
                }
            } else if (hasNonConstant) {
                cost /= (double) nonConstant;
            }
            return cost;
        }

        /**
         * Counts the variables in the list for which {@code hasValue()} is true.
         *
         * @param vars the variables to inspect
         * @return the number of variables that carry a value
         */
        private int countConstantVars(List<Var> vars) {
            return (int) vars.stream().filter(Var::hasValue).count();
        }

        /**
         * Collects the variables that have no value, have a non-null name and whose name is
         * not among the currently bound variables.
         *
         * @param vars the variables to filter
         * @return the unbound variables in iteration order; an immutable list when zero or
         *         one variable qualifies
         * @deprecated use {@link #getUnboundVars(List)} instead
         */
        @Deprecated(forRemoval=true, since="4.1.0")
        protected List<Var> getUnboundVars(Iterable<Var> vars) {
            List<Var> unbound = null;
            for (Var candidate : vars) {
                boolean qualifies = !candidate.hasValue()
                        && candidate.getName() != null
                        && !this.boundVars.contains(candidate.getName());
                if (!qualifies) {
                    continue;
                }
                if (unbound == null) {
                    // First hit: an immutable singleton is enough.
                    unbound = Collections.singletonList(candidate);
                } else {
                    if (unbound.size() == 1) {
                        // Second hit: switch to a growable list before appending.
                        unbound = new ArrayList<Var>(unbound);
                    }
                    unbound.add(candidate);
                }
            }
            if (unbound == null) {
                return Collections.emptyList();
            }
            return unbound;
        }

        /**
         * Collects the variables that have no value, have a non-null name and whose name is
         * not among the currently bound variables. Lists of size zero and one are handled
         * without iterating.
         *
         * @param vars the variables to filter
         * @return the unbound variables in list order; an immutable list when zero or one
         *         variable qualifies
         */
        protected List<Var> getUnboundVars(List<Var> vars) {
            switch (vars.size()) {
            case 0:
                return List.of();
            case 1: {
                Var only = vars.get(0);
                boolean bound = only.hasValue()
                        || only.getName() == null
                        || this.boundVars.contains(only.getName());
                return bound ? List.<Var>of() : List.<Var>of(only);
            }
            default:
                break;
            }

            List<Var> unbound = null;
            for (Var candidate : vars) {
                boolean qualifies = !candidate.hasValue()
                        && candidate.getName() != null
                        && !this.boundVars.contains(candidate.getName());
                if (!qualifies) {
                    continue;
                }
                if (unbound == null) {
                    // First hit: an immutable singleton is enough.
                    unbound = List.of(candidate);
                } else {
                    if (unbound.size() == 1) {
                        // Second hit: switch to a growable list before appending.
                        unbound = new ArrayList<Var>(unbound);
                    }
                    unbound.add(candidate);
                }
            }
            if (unbound == null) {
                return Collections.emptyList();
            }
            return unbound;
        }

        /**
         * Computes how often the given variables occur beyond the occurrences in
         * {@code ownUnboundVars} itself: the summed frequency of the distinct variables
         * minus the length of the list.
         *
         * @param ownUnboundVars unbound variables of one expression, duplicates included
         * @param varFreqMap     occurrence count of each variable; must contain every
         *                       variable of {@code ownUnboundVars}
         * @return the summed frequency minus {@code ownUnboundVars.size()}, or 0 for an
         *         empty list
         */
        protected int getForeignVarFreq(List<Var> ownUnboundVars, Map<Var, Integer> varFreqMap) {
            int ownCount = ownUnboundVars.size();
            if (ownCount == 0) {
                return 0;
            }
            if (ownCount == 1) {
                Var single = ownUnboundVars.get(0);
                return varFreqMap.get(single) - 1;
            }
            // Each distinct variable contributes its frequency once.
            int totalFreq = 0;
            for (Var distinct : new HashSet<Var>(ownUnboundVars)) {
                totalFreq += varFreqMap.get(distinct).intValue();
            }
            return totalFreq - ownCount;
        }

        /**
         * Gathers the variables of every statement pattern found while visiting an
         * expression. The visit runs lazily on the first call to {@link #getVars()}.
         */
        private static class StatementPatternVarCollector
        extends StatementPatternVisitor {
            /** Expression whose statement patterns are inspected. */
            private final TupleExpr tupleExpr;
            /** Collected variables; {@code null} until a pattern was accepted or the visit finished. */
            private List<Var> vars;

            public StatementPatternVarCollector(TupleExpr tupleExpr) {
                this.tupleExpr = tupleExpr;
            }

            /** Appends the variables of the accepted pattern, creating the list on first use. */
            @Override
            protected void accept(StatementPattern node) {
                List<Var> patternVars = node.getVarList();
                if (this.vars != null) {
                    this.vars.addAll(patternVars);
                    return;
                }
                this.vars = new ArrayList<Var>(patternVars);
            }

            /**
             * Returns the collected variables, visiting the expression first if nothing has
             * been collected yet.
             *
             * @return the variables of all accepted statement patterns, or an empty list
             *         when the visit accepted none
             * @throws IllegalStateException wrapping any exception thrown by the visit
             */
            public List<Var> getVars() {
                if (this.vars != null) {
                    return this.vars;
                }
                try {
                    this.tupleExpr.visit(this);
                }
                catch (Exception e) {
                    throw new IllegalStateException(e);
                }
                if (this.vars == null) {
                    // Nothing was accepted; remember that so the visit is not repeated.
                    this.vars = Collections.emptyList();
                }
                return this.vars;
            }
        }
    }
}

