package org.wikiwebserver.util;

import java.util.LinkedList;
import java.util.List;
import java.util.Map;

import org.wikiwebserver.core.ForgetfulMap;

/**
 * In-memory velocity (rate) limiting keyed by a {@code type} and an {@code id}.
 *
 * <p>For every (type, id) pair a list of millisecond timestamps (as returned by
 * {@link System#currentTimeMillis()}) of accepted events is kept. A new event is
 * rejected when the {@code limit}-th most recent stored timestamp still lies
 * inside the trailing {@code period}.
 *
 * <p>All public methods are {@code static synchronized}, so they are mutually
 * exclusive on this class's monitor. {@link #velocityData} itself is public;
 * direct access to it from outside this class is not covered by that lock.
 */
public class VelocityEnforcement {

	/** Size argument handed to the outer (per-type) {@code ForgetfulMap}. */
	public static final int MAX_TYPES = 100;

	/** Size argument handed to each inner (per-id) {@code ForgetfulMap}. */
	public static final int MAX_IDS = 1000;

	/**
	 * Size above which one old timestamp is dropped whenever a new one is
	 * appended to an already full history (see {@code enforceVelocity}).
	 */
	public static final int MAX_TIMES = 20;

	/**
	 * Timestamp history: type -> id -> event times in milliseconds.
	 * NOTE(review): {@code ForgetfulMap} is presumably a size-bounded map that
	 * discards entries once its capacity is exceeded -- confirm against its source.
	 */
	public static final Map<String, Map<String, List<Long>>> velocityData =
		new ForgetfulMap<String, Map<String, List<Long>>>(MAX_TYPES);

	/**
	 * Looks up the timestamp history for the given pair without creating anything.
	 *
	 * @return the stored list, or {@code null} when no history exists for the pair
	 */
	private static List<Long> findVelocityTimes(String type, String id) {
		Map<String, List<Long>> byId = velocityData.get(type);
		return (byId == null) ? null : byId.get(id);
	}

	/**
	 * Returns the timestamp history for the given pair, registering an empty
	 * inner map and/or list first when none is present yet.
	 */
	private static List<Long> getVelocityTimes(String type, String id) {
		Map<String, List<Long>> byId = velocityData.get(type);
		if (byId == null) {
			byId = new ForgetfulMap<String, List<Long>>(MAX_IDS);
			velocityData.put(type, byId);
		}

		List<Long> times = byId.get(id);
		if (times == null) {
			times = new LinkedList<Long>();
			byId.put(id, times);
		}

		return times;
	}

	/**
	 * Shared argument validation for the public entry points.
	 *
	 * @throws IllegalArgumentException if {@code type} or {@code id} is
	 *         {@code null}, or if {@code limit} is not positive
	 */
	private static void checkArguments(String type, String id, int limit) {
		if (type == null || id == null) {
			throw new IllegalArgumentException("Type and ID for velocity enforcement can not be null.");
		}
		if (limit <= 0) {
			throw new IllegalArgumentException("Velocity limit must be greater than zero.");
		}
	}

	/**
	 * Records an event for (type, id) and enforces the velocity limit without
	 * any penalty. Equivalent to calling the five-argument overload with a
	 * penalty of {@code 0}.
	 *
	 * @param type   category of the event; must not be {@code null}
	 * @param id     identity being limited within the category; must not be {@code null}
	 * @param limit  number of events allowed within {@code period}; must be &gt; 0
	 * @param period length of the sliding window in milliseconds
	 * @throws VelocityExceededException if the limit has been exceeded
	 * @throws IllegalArgumentException  if an argument is invalid
	 */
	public static synchronized void enforceVelocity(String type, String id, int limit, long period)
		throws VelocityExceededException {

		enforceVelocity(type, id, limit, period, 0);
	}

	/**
	 * Records an event for (type, id) and enforces the velocity limit,
	 * optionally penalising callers that keep hitting the limit.
	 *
	 * <p>While fewer than {@code limit} timestamps are stored the event is
	 * simply recorded. Otherwise the {@code limit}-th most recent timestamp is
	 * compared with {@code now - period}: if it is newer, the event is rejected
	 * and not recorded; if not, the event is recorded and, once the history
	 * holds more than {@link #MAX_TIMES} entries, the oldest one is removed.
	 *
	 * @param type    category of the event; must not be {@code null}
	 * @param id      identity being limited within the category; must not be {@code null}
	 * @param limit   number of events allowed within {@code period}; must be &gt; 0
	 * @param period  length of the sliding window in milliseconds
	 * @param penalty when &gt; 0, a rejected call rewrites the compared timestamp
	 *                so that it only leaves the window {@code penalty}
	 *                milliseconds after the rejected call
	 * @throws VelocityExceededException if the limit has been exceeded
	 * @throws IllegalArgumentException  if an argument is invalid
	 */
	public static synchronized void enforceVelocity(String type, String id, int limit, long period, long penalty)
		throws VelocityExceededException {

		checkArguments(type, id, limit);

		long time = System.currentTimeMillis();
		List<Long> velTimes = getVelocityTimes(type, id);

		// Not enough history yet to possibly exceed the limit: just record.
		if (velTimes.size() < limit) {
			velTimes.add(Long.valueOf(time));
			return;
		}

		// The limit-th most recent accepted event decides whether the window is full.
		int idx = velTimes.size() - limit;
		long initialTime = velTimes.get(idx);

		if (initialTime > time - period) {
			// Apply the penalty: storing (time + penalty - period) makes this
			// slot fall out of the window at (time + penalty). The
			// "initialTime < time" guard skips the rewrite while the stored
			// stamp is not in the past (e.g. a future stamp left by an earlier
			// penalty larger than the period); otherwise every further
			// rejected call re-applies the penalty.
			if (penalty > 0 && initialTime < time) {
				velTimes.set(idx, Long.valueOf(time + (penalty - period)));
			}

			String msg = "Velocity for " + id + " in " + type + " has been exceeded.";
			throw new VelocityExceededException(msg);
		}

		velTimes.add(Long.valueOf(time));
		if (velTimes.size() > MAX_TIMES) {
			velTimes.remove(0);
		}
	}

	/**
	 * Reports how long a caller has to wait before {@code enforceVelocity}
	 * with the same {@code limit} and {@code period} would accept an event.
	 *
	 * <p>This is a pure query: it neither records an event nor creates history
	 * entries for pairs that have none.
	 *
	 * <p>Note that the order of {@code period} and {@code limit} is the reverse
	 * of the order used by {@code enforceVelocity}.
	 *
	 * @param type   category of the event; must not be {@code null}
	 * @param id     identity being limited within the category; must not be {@code null}
	 * @param period length of the sliding window in milliseconds
	 * @param limit  number of events allowed within {@code period}; must be &gt; 0
	 * @return remaining wait in milliseconds; {@code 0} when an event would be
	 *         accepted right now (never negative)
	 * @throws IllegalArgumentException if an argument is invalid
	 */
	public static synchronized long getWaitTimeUntilAcceptableVelocity(String type, String id, long period, int limit) {

		checkArguments(type, id, limit);

		long time = System.currentTimeMillis();

		List<Long> velTimes = findVelocityTimes(type, id);
		if (velTimes == null || velTimes.size() < limit) {
			return 0;
		}

		long initialTime = velTimes.get(velTimes.size() - limit);

		// Positive exactly when enforceVelocity would reject; clamp so an
		// already-acceptable velocity reports "no wait" instead of a negative value.
		return Math.max(0L, initialTime - (time - period));
	}
}

