#include "pmap.h"

PhotonMap::PhotonMap(int maxP) {
  // Lookup tables used to decode the 8-bit directions stored by Insert:
  // entry i corresponds to theta = i*pi/256 and phi = i*2*pi/256.
  for(int i = 0; i < 256; i++) {
    const double theta = (double)i * (1./256.) * M_PI;
    const double phi = theta * 2.;
    cost[i] = cos(theta);
    sint[i] = sin(theta);
    cosp[i] = cos(phi);
    sinp[i] = sin(phi);
  }

  numStored = 0;
  // -1 means "no internal kd-tree nodes yet" -- the same value BuildTree
  // stores when fewer than two photons are present.
  halfStored = -1;
  // A negative capacity would otherwise wrap to a huge allocation size.
  maxPhotons = (maxP > 0) ? maxP : 0;
  // First photon index ScalePow has not scaled yet (photons are 1-based).
  lastScale = 1;

  // Slot 0 is unused: photons live in 1..maxPhotons so the balanced tree can
  // use heap indexing (children of node i at 2i and 2i+1).
  // size_t arithmetic avoids int overflow of maxPhotons+1.
  photons = (Photon*)malloc(sizeof(Photon) * ((size_t)maxPhotons + 1));
  if(photons == NULL)
    maxPhotons = 0;  // allocation failed: record that no storage exists

  // Start with an "empty" bounding box so the first photon sets both bounds.
  min[0] = min[1] = min[2] = 1e8f;
  max[0] = max[1] = max[2] = -1e8f;
}

PhotonMap::~PhotonMap() {
  // Release the photon array obtained with malloc in the constructor.
  Photon *const storage = photons;
  free(storage);
}

void PhotonMap::Insert(const Point p, const float pw[3], const Vector v) {
  // Store one photon at position p with power pw, quantizing direction v
  // into two bytes (theta, phi). Photons are kept 1-based, so the valid slots
  // are photons[1..maxPhotons]; once they are full, further photons are
  // silently dropped. (The previous `>` test allowed a write to
  // photons[maxPhotons+1], one past the allocation.)
  if(numStored >= maxPhotons)
    return;

  numStored++;
  Photon *const cur = &photons[numStored];

  cur->pos[0] = p.x; cur->pos[1] = p.y; cur->pos[2] = p.z;
  for(int i = 0; i < 3; i++) {
    cur->pow[i] = pw[i];

    // Grow the bounding box used by BalanceSegment to choose split axes.
    if(cur->pos[i] < min[i])
      min[i] = cur->pos[i];
    if(cur->pos[i] > max[i])
      max[i] = cur->pos[i];
  }

  // Clamp into acos's domain: rounding on a nominally unit vector can push
  // |v.z| slightly past 1, and acos would then yield NaN (int(NaN) is UB).
  double cosTheta = v.z;
  if(cosTheta > 1.)
    cosTheta = 1.;
  else if(cosTheta < -1.)
    cosTheta = -1.;

  // theta in [0, pi] maps to [0, 256]; 256 (theta == pi) is clamped to 255.
  int theta = int(acos(cosTheta)*(256.0/M_PI));
  if(theta > 255)
    cur->theta = 255;
  else
    cur->theta = (unsigned char)theta;

  // phi from atan2 is in [-pi, pi]; negative bins wrap into [128, 255].
  int phi = int(atan2(v.y, v.x)*(256.0/(2.0*M_PI)));
  if(phi > 255)
    cur->phi = 255;
  else if(phi < 0)
    cur->phi = (unsigned char)(phi+256);
  else
    cur->phi = (unsigned char)phi;
}

void PhotonMap::ScalePow(const float scale) {
  // Multiply the power of every photon stored since the previous call
  // (indices lastScale..numStored) by `scale`.
  for(int i = lastScale; i <= numStored; i++) {
    photons[i].pow[0] *= scale;
    photons[i].pow[1] *= scale;
    photons[i].pow[2] *= scale;
  }
  // Resume after the last photon scaled here. Setting this to numStored
  // (as before) made the next call scale that photon a second time.
  lastScale = numStored + 1;
}

void PhotonMap::BuildTree() {
  // Reorder photons[1..numStored] into a left-balanced kd-tree in heap order
  // (children of node i at 2i and 2i+1). Also set halfStored = numStored/2 - 1.

  // Zero or one photon: nothing to balance; halfStored becomes -1.
  if(numStored < 2) {
    halfStored = numStored/2 - 1;
    return;
  }

  // t1: the balanced tree as pointers, t1[i] = photon that belongs at node i.
  // t2: working array of pointers that MedianSplit partitions in place.
  Photon **t1 = (Photon**)malloc(sizeof(Photon*) * (numStored+1));
  Photon **t2 = (Photon**)malloc(sizeof(Photon*) * (numStored+1));
  if(t1 == NULL || t2 == NULL) {
    // Out of memory: leave photons[] unbalanced and report no internal
    // nodes, so a search never descends using an unset split axis.
    free(t1);
    free(t2);
    halfStored = -1;
    return;
  }

  for(int i = 0; i <= numStored; i++)
    t2[i] = &photons[i];

  BalanceSegment(t1, t2, 1, 1, numStored);
  free(t2);

  // Permute photons[] in place so that slot i receives *t1[i], following
  // each cycle of the permutation. The photon originally at the cycle start
  // k is saved in p; slot j is filled from slot d (the photon that belongs
  // at j), then we continue at d. When d wraps back to k the saved photon
  // closes the cycle. Visited slots are marked NULL in t1, and the next cycle
  // starts at the first slot that is still non-NULL.
  int d, j=1, k=1;
  Photon p = photons[j];

  for(int i = 1; i <= numStored; i++) {
    d = t1[j] - photons;
    t1[j] = NULL;
    if(d != k) {
      photons[j] = photons[d];
    } else {
      photons[j] = p;
      if(i < numStored) {
        for(; k <= numStored; k++)
          if(t1[k] != NULL)
            break;
        p = photons[k];
        j = k;
      }
      continue;
    }
    j = d;
  }
  free(t1);

  // Nodes with children (2*i <= numStored) are exactly i <= halfStored + 1.
  halfStored = numStored/2 - 1;
}

void PhotonMap::GetIrradiance(float ir[3], const Point p, const Normal n,
			      const float mDist, const int nPhot) const {
  ir[0] = ir[1] = ir[2] = 0;

  PhotonStack ps;
  ps.max = nPhot;
  ps.numFound = 0;
  ps.gotHeap = 0;
  ps.pos[0] = p.x; ps.pos[1] = p.y; ps.pos[2] = p.z;
  ps.distSquared = (float*)alloca(sizeof(float)*(nPhot+1));
  ps.distSquared[0] = mDist*mDist;
  ps.index = (const Photon**)alloca(sizeof(Photon*)*(nPhot+1));

  LocatePhotons(&ps, 1);
 if(ps.numFound < 8) return;
 //printf("\nFound %d photons",ps.numFound);

  for(int i = 1; i < ps.numFound; i++) {
    const Photon *cur = ps.index[i];
    float pdir[3];
    PhotonDir(pdir, cur);
    if ((pdir[0]*n.x + pdir[1]*n.y + pdir[2]*n.z) <= 0.0f) {
    	ir[0] += cur->pow[0];
    	ir[1] += cur->pow[1];
   	ir[2] += cur->pow[2];
	}
  }

  const float density = (1./M_PI)/(ps.distSquared[0]);

  ir[0] *= density;
  ir[1] *= density;
  ir[2] *= density;
}

void PhotonMap::LocatePhotons(PhotonStack *const ps, const int index) const {
  // Recursively search the heap-ordered kd-tree rooted at `index` for photons
  // closer than sqrt(ps->distSquared[0]) to ps->pos. Keep the ps->max nearest
  // in ps->index/ps->distSquared (slots 1..numFound). Requires BuildTree() to
  // have been called, since inner nodes' kdtree axes are read.

  // Past the end of the stored photons: the missing right child of the last
  // internal node, or an empty map (photons[1] would be uninitialized).
  if(index > numStored)
    return;

  const Photon *p = &photons[index];
  float dist1, dist2;

  // BuildTree stores halfStored = numStored/2 - 1, so the nodes that have
  // children (2*index <= numStored) are exactly index <= halfStored + 1.
  // Testing `index < halfStored` would stop two levels short and never
  // examine the photons in the final heap slots.
  if(index <= halfStored + 1) {
    // Signed distance from the query point to this node's splitting plane.
    dist1 = ps->pos[p->kdtree] - p->pos[p->kdtree];
    if(dist1 > 0.) {
      // Query lies on the right (2i+1) side: search it first, then the left
      // side only if the plane is within the current search radius.
      LocatePhotons(ps, 2*index+1);
      if((dist1*dist1) < ps->distSquared[0])
        LocatePhotons(ps, 2*index);
    } else {
      LocatePhotons(ps, 2*index);
      if((dist1*dist1) < ps->distSquared[0])
        LocatePhotons(ps, 2*index+1);
    }
  }

  // Squared distance from the query point to this photon.
  dist1 = p->pos[0] - ps->pos[0];
  dist2 = dist1*dist1;
  dist1 = p->pos[1] - ps->pos[1];
  dist2 += dist1*dist1;
  dist1 = p->pos[2] - ps->pos[2];
  dist2 += dist1*dist1;

  if(dist2 < ps->distSquared[0]) {
    if(ps->numFound < ps->max) {
      // Candidate list not full yet: append.
      ps->numFound++;
      ps->distSquared[ps->numFound] = dist2;
      ps->index[ps->numFound] = p;
    } else {
      int j, parent;

      // The first time the list overflows, turn slots 1..numFound into a
      // max-heap keyed on distSquared (sift-down from the last parent).
      if(ps->gotHeap == 0) {
        float dst2;
        const Photon *p2;
        int half_found = ps->numFound >> 1;
        for(int k = half_found; k >= 1; k--) {
          parent = k;
          p2 = ps->index[k];
          dst2 = ps->distSquared[k];
          while(parent <= half_found) {
            j = parent+parent;
            if(j < ps->numFound &&
               ps->distSquared[j] < ps->distSquared[j+1])
              j++;
            if(dst2 >= ps->distSquared[j])
              break;
            ps->distSquared[parent] = ps->distSquared[j];
            ps->index[parent] = ps->index[j];
            parent = j;
          }
          ps->distSquared[parent] = dst2;
          ps->index[parent] = p2;
        }
        ps->gotHeap = 1;
      }

      // Replace the farthest candidate (heap root) with this photon and
      // sift it down to restore the heap.
      parent = 1;
      j = 2;
      while(j <= ps->numFound) {
        if(j < ps->numFound &&
           ps->distSquared[j] < ps->distSquared[j+1])
          j++;
        if(dist2 > ps->distSquared[j])
          break;
        ps->distSquared[parent] = ps->distSquared[j];
        ps->index[parent] = ps->index[j];
        parent = j;
        j += j;
      }
      ps->index[parent] = p;
      ps->distSquared[parent] = dist2;
      // Shrink the search radius to the farthest kept candidate.
      ps->distSquared[0] = ps->distSquared[1];
    }
  }
}

void PhotonMap::BalanceSegment(Photon **t1, Photon **t2, const int index,
			       const int start, const int end) {
  // Place the photons t2[start..end] into the subtree of t1 rooted at
  // `index` (children at 2*index and 2*index+1). The median position is
  // chosen so the subtree is left-balanced, and the split axis is the longest
  // side of the current bounding box (min/max), which is narrowed around each
  // recursive call and restored afterwards.
  const int count = end - start + 1;

  // pow2 = largest power of two with 4*pow2 <= count.
  int pow2 = 1;
  while(4*pow2 <= count)
    pow2 *= 2;

  int med;
  if(3*pow2 <= count)
    med = start - 1 + 2*pow2;
  else
    med = end - pow2 + 1;

  // Longest bounding-box side wins; ties fall through toward axis 2.
  int axis;
  if((max[0]-min[0]) > (max[1]-min[1]) &&
     (max[0]-min[0]) > (max[2]-min[2]))
    axis = 0;
  else if((max[1]-min[1]) > (max[2]-min[2]))
    axis = 1;
  else
    axis = 2;

  MedianSplit(t2, start, end, med, axis);
  Photon *const node = t2[med];
  t1[index] = node;
  node->kdtree = axis;

  const int leftChild = 2*index;
  const int rightChild = leftChild + 1;

  // Lower half t2[start..med-1]: bounded above by the splitting plane.
  if(med - 1 > start) {
    const float savedMax = max[axis];
    max[axis] = node->pos[axis];
    BalanceSegment(t1, t2, leftChild, start, med-1);
    max[axis] = savedMax;
  } else if(med - 1 == start) {
    t1[leftChild] = t2[start];   // single photon: a leaf
  }

  // Upper half t2[med+1..end]: bounded below by the splitting plane.
  if(med + 1 < end) {
    const float savedMin = min[axis];
    min[axis] = node->pos[axis];
    BalanceSegment(t1, t2, rightChild, med+1, end);
    min[axis] = savedMin;
  } else if(med + 1 == end) {
    t1[rightChild] = t2[end];    // single photon: a leaf
  }
}

#define swap(p,a,b) { Photon *p2 = p[a]; p[a] = p[b]; p[b] = p2; }

void PhotonMap::MedianSplit(Photon **p, const int start, const int end,
			    const int median, const int axis) {
  int left = start;
  int right = end;

  while(right > left) {
    const float v = p[right]->pos[axis];
    int i = left-1;
    int j = right;
    for(;;) {
      while(p[++i]->pos[axis] < v);
      while(p[--j]->pos[axis] > v && j > left);
      if(i >= j) break;
      swap(p, i, j);
    }

    swap(p, i, right);
    if(i >= median)
      right = i-1;
    if(i <= median)
      left = i+1;
  }
}


void PhotonMap::PhotonDir(float *dir, const Photon *p) const {
  // Decode the photon's quantized (theta, phi) bytes into a direction vector
  // using the sin/cos tables filled in by the constructor.
  const int thetaIdx = p->theta;
  const int phiIdx = p->phi;
  dir[2] = cost[thetaIdx];
  dir[0] = sint[thetaIdx]*cosp[phiIdx];
  dir[1] = sint[thetaIdx]*sinp[phiIdx];
}