package erp.graph;

import java.util.*;
import erp.util.*;

/** An augmented binary tree representing a (non-strictly)
    monotonically increasing sequence.  Notable operations:

    * Compute the average value of any contiguous interval in O(log N) time

    * Ensure that all elements at or after index i are at least value v in O(log N) time

    * Limit all values to at most v in O(log N) time

    * Scale all values according to v/(1+v) in O(N) time

    Organization:

    Each node maintains three values: a minimum value, a maximum
    value, and a sum. 

    Nominally:
    * sum: the sum of all leaf values in the node's subtree.
    * minimum: the smallest leaf value in the subtree; no leaf is allowed to be smaller.
    * maximum: the largest leaf value in the subtree.
    
    A parent that has minimum==maximum overrides any data of its
    children: all the children have that value.

    Indexing scheme: we use an in-place (in-order) allocation that uses
    2N-1 nodes (storage is reserved for 2N) to store N values, where N
    is always rounded up to a power of two. We use the term "address"
    to mean the node number. The address for index i is always 2*i.
    Each parent has an address equal to the average of its two
    children's addresses.

                   7                    ROOT ADDRESS
          3                11           PARENT ADDRESSES
      1       5       9         13      PARENT ADDRESSES
    0   2   4   6   8   10   12   14    ADDRESS OF LEAF
   (0) (1) (2) (3) (4)  (5)  (6)  (7)   INDEX
  
   Since nodes can override their children, most operations proceed
   from the root and traverse downwards. Most write operations crawl
   down the tree looking for the nodes that need to be modified in
   order to achieve the correct changes, then crawl back up the tree,
   fixing up the parent nodes of the changed children.

   The run-time performance derives from the fact that the tree
   contains monotonically increasing values, and that operations deal
   with contiguous blocks of indices. Any contiguous block of indices
   is exactly covered by O(log N) maximal subtrees whose leaves all lie
   inside the block. In the figure above, the indices 4-7 are covered
   by a single node: 11. The indices 1-7 are covered by nodes 2, 5, and
   11. Write operations need only locate these nodes.

**/
public final class LearningRatesTree implements LearningRates
{
    /** Flat node storage: STORAGE_PER_NODE doubles (min, max, sum) per node address. */
    DoubleVector vs;

    /** Number of values (leaves); a power of two once allocated, 0 before that. */
    int size;

    /** Doubles stored per node: min at offset 0, max at offset 1, sum at offset 2. */
    static final int STORAGE_PER_NODE = 3;

    /** Address of the root node (size-1 for a tree with size leaves). */
    int ROOT_NODE;

    /** Number of node slots reserved (2*size); addresses in use are 0..2*size-2. */
    int NUM_NODES;

    /** Creates a tree holding 4 values. */
    public LearningRatesTree()
    {
        this(4);
    }

    /**
     * Creates a tree holding at least {@code sz} values (rounded up to a
     * power of two). If {@code sz <= 0} no storage is allocated until
     * {@link #ensureSize} is called with a positive size.
     */
    public LearningRatesTree(int sz)
    {
        ensureSize(sz);
    }

    /** Returns a copy of this tree whose storage is duplicated via DoubleVector.copy(). */
    public LearningRatesTree copy()
    {
        LearningRatesTree lt = new LearningRatesTree(size);
        lt.size = size;
        // vs is still null if this tree was never given a positive size.
        lt.vs = (vs == null) ? null : vs.copy();
        lt.ROOT_NODE = ROOT_NODE;
        lt.NUM_NODES = NUM_NODES;
        return lt;
    }

    /** Returns the number of values stored (a power of two, or 0 if unallocated). */
    public int size()
    {
        return size;
    }

    /**
     * Returns the smallest power of two strictly greater than {@code n}
     * (for 0 <= n < 2^30). Note that a power of two maps to twice itself.
     */
    int nextPowerOfTwo(int n)
    {
        n |= (n >> 1);
        n |= (n >> 2);
        n |= (n >> 4);
        n |= (n >> 8);
        n |= (n >> 16);

        return n + 1;
    }

    /**
     * Grows the tree so that it holds at least {@code newsize} values
     * (rounded up to a power of two). Existing values are preserved; every
     * new index is filled with the largest existing value so the sequence
     * stays monotone. Does nothing if {@code newsize <= size()}.
     */
    public void ensureSize(int newsize)
    {
        if (newsize <= size)
            return;

        // Constrain size to a power of two: this eliminates many special
        // cases when crawling the tree.
        if ((newsize & (newsize - 1)) != 0)
            newsize = nextPowerOfTwo(newsize);

        int vssize = 2 * STORAGE_PER_NODE * newsize;

        if (vs == null)
            vs = new DoubleVector(vssize);

        vs.addZeros(vssize - vs.size());

        // With in-order addressing, a complete tree of newsize leaves is
        // rooted at address newsize-1.
        int newRootNode = newsize - 1;
        int newMaxNodes = 2 * newsize;

        if (size > 0)
        {
            // Addresses are stable under growth: the old root becomes the
            // leftmost descendant of the new root. Walk up from it, making
            // each newly created right subtree uniform at the largest value
            // to its left.
            int node = ROOT_NODE;
            while (node != newRootNode)
            {
                node = parent(node);
                double fill = getMax(leftChild(node));
                setUniform(rightChild(node), fill);
                fixup(node);
            }
        }

        ROOT_NODE = newRootNode;
        NUM_NODES = newMaxNodes;
        size = newsize;
    }

    /** Prints the whole tree; see {@link #debug(int)}. */
    public void debug()
    {
        debug(size);
    }

    /**
     * Debugging aid: prints the first {@code maxsize} leaves and their
     * ancestors as an ASCII tree (each cell is "[address] min,max,sum"),
     * followed by index, get(), dumped value and cumulative sum for up to
     * 16 indices.
     */
    public void debug(int maxsize)
    {
        int TAB = 10;

        ArrayList<String> lines = new ArrayList<String>();
        ArrayList<Integer> tabs = new ArrayList<Integer>();

        for (int step = 0; step < size && tabs.size() != 1; step++)
        {
            int tabidx = 0;
            String line = "";

            ArrayList<Integer> newtabs = new ArrayList<Integer>();

            for (int idx = (1 << step) - 1; idx < 2 * maxsize; idx += 2 << step)
            {
                String cell = String.format("[%d] %.0f,%.0f,%.0f", idx, getMin(idx), getMax(idx), getSum(idx));

                int pos;
                if (step == 0)
                    pos = idx * TAB + TAB / 2;
                else
                {
                    // centre this node between its two children from the row below
                    int tab0 = tabs.get(tabidx++);
                    int tab1 = tabs.get(tabidx++);
                    pos = (tab1 + tab0) / 2;
                }

                line = line + spaces(pos - line.length() - cell.length() / 2) + cell;

                newtabs.add(pos);
            }
            lines.add(line);
            tabs = newtabs;
        }

        // rows were built leaves-first; print the root row first
        for (int i = lines.size() - 1; i >= 0; i--)
            System.out.println(lines.get(i));

        double values[] = dump();

        for (int i = 0; i < Math.min(16, size); i++)
            System.out.printf("%4d %6.5f %6.5f %10.5f\n", i, get(i), values[i], cumulativeSum(i));

        System.out.println("\n");
    }

    /** Returns a string of {@code c} spaces, or "" when c <= 0. */
    String spaces(int c)
    {
        if (c > 0)
            return String.format("%" + c + "s", "");
        return "";
    }

    double getMin(int node)
    {
        return vs.get(node * STORAGE_PER_NODE);
    }

    void setMin(int node, double v)
    {
        vs.set(node * STORAGE_PER_NODE, v);
    }

    double getMax(int node)
    {
        return vs.get(node * STORAGE_PER_NODE + 1);
    }

    void setMax(int node, double v)
    {
        vs.set(node * STORAGE_PER_NODE + 1, v);
    }

    double getSum(int node)
    {
        return vs.get(node * STORAGE_PER_NODE + 2);
    }

    void setSum(int node, double v)
    {
        vs.set(node * STORAGE_PER_NODE + 2, v);
    }

    /** Makes {@code node} a uniform subtree: min == max == v and sum == leafCount * v. */
    private void setUniform(int node, double v)
    {
        setMin(node, v);
        setMax(node, v);
        setSum(node, numberOfChildren(node) * v);
    }

    /** Address of the leaf holding value index {@code idx}. */
    static final int idx2node(int idx)
    {
        return idx * 2;
    }

    /** True when {@code node} is the left child of its parent. */
    static final boolean isLeftChild(int node)
    {
        int v = node + 1;
        int w = v ^ (v & (v - 1));   // lowest set bit of node+1
        return (v & ((2 * w))) == 0;
    }

    static final boolean isRightChild(int node)
    {
        return !isLeftChild(node);
    }

    /** Address of the right child of an internal node (not meaningful for leaves). */
    static final int rightChild(int node)
    {
        int v = (node + 1) / 2;
        v = v ^ (v & (v - 1));

        return node + v;
    }

    /** Address of the left child of an internal node (not meaningful for leaves). */
    static final int leftChild(int node)
    {
        int v = (node + 1) / 2;
        v = v ^ (v & (v - 1));

        return node - v;
    }

    /** Address of the parent of {@code node}. */
    static final int parent(int node)
    {
        int v = (node + 1);
        int lo = v ^ (v & (v - 1));

        return (node & (~(lo * 2))) + lo;
    }

    /**
     * Address of the next node to the right on the same level
     * ({@code node + 2*leafCount}); it does not necessarily share node's parent.
     */
    static final int rightSibling(int node)
    {
        int v = node + 1;
        int lo = v ^ (v & (v - 1));

        return v + lo * 2 - 1;
    }

    /**
     * Address of the previous node to the left on the same level
     * ({@code node - 2*leafCount}); may be negative at the left edge.
     */
    static final int leftSibling(int node)
    {
        int v = node + 1;
        int lo = v ^ (v & (v - 1));

        return v - lo * 2 - 1;
    }

    static final boolean isLeaf(int node)
    {
        return numberOfChildren(node) == 1;
    }

    /** Number of leaves in {@code node}'s subtree (1 for a leaf): the lowest set bit of node+1. */
    static final int numberOfChildren(int node)
    {
        int v = node + 1;

        return v ^ (v & (v - 1));
    }

    /** Address of the leftmost leaf in {@code node}'s subtree. */
    private static int firstLeafAddress(int node)
    {
        return node - numberOfChildren(node) + 1;
    }

    /** Address of the rightmost leaf in {@code node}'s subtree. */
    private static int lastLeafAddress(int node)
    {
        return node + numberOfChildren(node) - 1;
    }

    /**
     * Raises every value at index {@code idx} or later to at least
     * {@code llimit}; earlier values are untouched.
     *
     * setLowerLimit is a bit more complicated than setUpperLimit
     * because of the limitation on which indices we can modify: it
     * means that we must sometimes dive more deeply into the graph,
     * past a place where min==max. Consequently, as we're descending,
     * we must track the override value and propagate it as we go.
     */
    public void setLowerLimit(int idx, double llimit)
    {
        setLowerLimitRecurse(idx2node(idx), ROOT_NODE, llimit, false, 0);
    }

    /**
     * @param limitnode      leaf address of the first index to be limited
     * @param node           subtree currently being visited
     * @param llimit         lower limit to apply
     * @param overriding     true if an ancestor was uniform (min==max) and its
     *                       value must be pushed down onto this node
     * @param override_value the ancestor's uniform value when overriding
     */
    void setLowerLimitRecurse(int limitnode, int node, double llimit, boolean overriding, double override_value)
    {
        if (overriding)
            setUniform(node, override_value);

        // Subtree lies entirely before the limit: nothing to raise.
        if (lastLeafAddress(node) < limitnode)
            return;

        double min = getMin(node), max = getMax(node);

        // Values are monotone: if the smallest meets the limit, they all do.
        if (min >= llimit)
            return;

        // Uniform subtree entirely inside the limited range: overwrite it.
        if (min == max && firstLeafAddress(node) >= limitnode)
        {
            setUniform(node, llimit);
            return;
        }

        // Uniform subtree straddling the limit: its children hold stale data,
        // so push its value down while descending.
        if (min == max && !overriding)
        {
            overriding = true;
            override_value = min;
        }

        setLowerLimitRecurse(limitnode, leftChild(node), llimit, overriding, override_value);
        setLowerLimitRecurse(limitnode, rightChild(node), llimit, overriding, override_value);

        fixup(node);
    }

    /** Lowers every value greater than {@code ulimit} to {@code ulimit}. */
    public void setUpperLimit(double ulimit)
    {
        setUpperLimitRecurse(ROOT_NODE, ulimit);
    }

    void setUpperLimitRecurse(int node, double ulimit)
    {
        double min = getMin(node), max = getMax(node);

        if (max <= ulimit) // nothing to do
            return;

        if (min == max)  // uniform subtree (includes every leaf): overwrite it
        {
            setUniform(node, ulimit);
            return;
        }

        setUpperLimitRecurse(leftChild(node), ulimit);
        setUpperLimitRecurse(rightChild(node), ulimit);

        fixup(node);
    }

    /**
     * Recomputes this internal node's min, max and sum from its two
     * children. Relies on monotonicity: min comes from the left child and
     * max from the right. The children must hold valid data (e.g. values
     * pushed down from an overriding ancestor), or the result is garbage.
     * Must not be called on a leaf.
     */
    void fixup(int node)
    {
        int left = leftChild(node), right = rightChild(node);

        setMax(node, getMax(right));
        setMin(node, getMin(left));
        setSum(node, getSum(left) + getSum(right));
    }

    /** Returns the sum of the values at indices 0..idx inclusive (0 when idx < 0). */
    public double cumulativeSum(int idx)
    {
        if (idx < 0)
            return 0;

        int node = ROOT_NODE;
        int goalnode = idx2node(idx);

        double sum = 0;
        int nodesToLeft = 0; // leaves already accounted for to our left

        while (true)
        {
            double min = getMin(node), max = getMax(node);

            if (min == max)
            {
                // Uniform subtree: indices nodesToLeft..idx all hold min.
                sum += min * (idx - nodesToLeft + 1);
                return sum;
            }

            int left = leftChild(node), right = rightChild(node);

            if (node > goalnode)
            {
                node = left;
            }
            else
            {
                // The whole left subtree lies before the goal: take its sum.
                sum += getSum(left);
                nodesToLeft += numberOfChildren(left);
                node = right;
            }
        }
    }

    /** Returns the average of the values at indices idx0..idx1 inclusive. */
    public double mean(int idx0, int idx1)
    {
        double sum0 = cumulativeSum(idx0 - 1);
        double sum1 = cumulativeSum(idx1);

        return (sum1 - sum0) / (idx1 - idx0 + 1);
    }

    /** Returns the value at index {@code idx}. */
    public double get(int idx)
    {
        int goalnode = idx2node(idx);
        int node = ROOT_NODE;

        // Descend from the root; the first uniform (min==max) node on the
        // path overrides everything beneath it, so its value is the answer.
        while (true)
        {
            double min = getMin(node), max = getMax(node);

            if (min == max || node == goalnode)
                return min;

            if (node > goalnode)  // goal lies in the left subtree
                node = leftChild(node);
            else
                node = rightChild(node);
        }
    }

    /** Writes the effective value of every leaf under {@code node} into {@code values}. */
    void dump_recurse(int node, double[] values)
    {
        double min = getMin(node), max = getMax(node);

        if (isLeaf(node))
        {
            values[node / 2] = min;
            return;
        }

        if (min == max)
        {
            // Uniform subtree: every leaf beneath it holds min. Leaf
            // addresses are even, so step over them two at a time.
            int last = lastLeafAddress(node);
            for (int leaf = firstLeafAddress(node); leaf <= last; leaf += 2)
                values[leaf / 2] = min;

            return;
        }

        dump_recurse(leftChild(node), values);
        dump_recurse(rightChild(node), values);
    }

    /** Returns the effective values of all {@code size} indices, in index order. */
    double[] dump()
    {
        double values[] = new double[size];

        dump_recurse(ROOT_NODE, values);

        return values;
    }

    /**
     * Decreases each element in the (implicit) array according to
     * n' = n/(n+1), then rebuilds the whole tree. O(N).
     */
    public void age()
    {
        double values[] = dump();
        vs.setToZero();

        // compute new leaf values.
        for (int i = 0; i < values.length; i++)
        {
            double newval = values[i] / (1 + values[i]);
            setUniform(idx2node(i), newval);
        }

        // Rebuild internal nodes bottom-up, one level at a time: level k >= 1
        // holds addresses (2^k - 1) + j * 2^(k+1). Stop after the root's
        // level. A single-leaf tree (ROOT_NODE == 0) has no internal nodes;
        // the bound also prevents an endless loop in that case.
        for (int rowoffset = 1, rowinc = 4; rowoffset <= ROOT_NODE; rowoffset = rowoffset * 2 + 1, rowinc *= 2)
        {
            for (int i = rowoffset; i < NUM_NODES; i += rowinc)
                fixup(i);
        }
    }

    /** Ad-hoc demo: prints the sibling addressing, then exercises the tree. */
    public static void main(String args[])
    {
        for (int i = 0; i < 15; i++)
        {
            System.out.printf("%4d right sibling is %4d\n", i, rightSibling(i));
            System.out.printf("%4d left sibling is %4d\n", i, leftSibling(i));
        }

        LearningRatesTree lrates = new LearningRatesTree(2);

        lrates.setLowerLimit(0, 1.0);
        lrates.debug();

        lrates.ensureSize(8);
        lrates.debug();

        lrates.setLowerLimit(4, 1.0);
        lrates.debug();

        lrates.setLowerLimit(3, 4.0);
        lrates.debug();

        lrates.setLowerLimit(5, 0);
        lrates.debug();

        lrates.setLowerLimit(6, 8);
        lrates.debug();

        lrates.setUpperLimit(2);
        lrates.debug();
    }

}
