/*
  Mallat's Periodic Wavelet Transform

  pdwt( u[], N, J, h[], g[], L):
  [0] If J>0, then do [1] to [4]
  [1]    Allocate temp[0]=0,...,temp[N-1]=0
  [2]    Compute pcqfilter( temp[], u[], N/2, h[], g[], L)
  [3]    Compute pdwt( temp[], N/2, J-1, h[], g[], L )
  [4]    For i=0 to N-1, copy u[i] = temp[i]
  [5] Return


  Reconstruction of $N=2^JK$ Samples from Mallat's Periodic Wavelet Expansion

  ipdwt( u[], N, J, h[], g[], L):
  [0] If J>0, then do [1] through [4]
  [1]    Allocate temp[0]=0,...,temp[N-1]=0
  [2]    Compute ipdwt( u[], N/2, J-1, h[], g[], L )
  [3]    Compute ipcqfilter( temp[], u[], N/2, h[], g[], L )
  [4]    For i=0 to N-1, let u[i] = temp[i]
  [5] Return

*/
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>

#include "pcqfilt.c"

/*
  pdwt: in-place periodic discrete wavelet transform (Mallat's algorithm).

  u[]   -- N samples on entry; overwritten with the J-level transform.
  N     -- current signal length. NOTE(review): presumably must be
           divisible by 2^J so every level halves evenly; not checked here.
  J     -- number of levels still to compute; J<=0 (or N<=0) leaves u
           unchanged.
  h, g  -- filter coefficient arrays of length L, passed unchanged to
           pcqfilter() (defined in pcqfilt.c).

  Each level filters u into a zeroed scratch buffer of N floats, recurses
  on the first N/2 entries of that buffer, then copies all N entries back
  into u.  Assumes pcqfilter() leaves the part to be refined further in
  temp[0..N/2-1] -- confirm against pcqfilt.c.

  Aborts the process (exit(EXIT_FAILURE)) if scratch memory cannot be
  allocated.
*/
void
pdwt(float u[], int N, int J, const float h[], const float g[], int L)
{
  float *temp;
  int i;

  if (J <= 0 || N <= 0)
    return;                     /* nothing left to transform */

  temp = calloc((size_t)N, sizeof *temp);
  if (temp == NULL) {
    /* An assert() here would be compiled out under NDEBUG and the NULL
       pointer would then be dereferenced; fail loudly instead. */
    fprintf(stderr, "pdwt: out of memory allocating %d floats\n", N);
    exit(EXIT_FAILURE);
  }

  pcqfilter(temp, u, N/2, h, g, L);     /* one analysis level into temp */
  pdwt(temp, N/2, J-1, h, g, L);        /* refine the first half further */
  for (i = 0; i < N; i++)
    u[i] = temp[i];

  free(temp);
}

/*
  ipdwt: in-place inverse of pdwt() (Mallat's periodic reconstruction).

  u[]   -- N transform coefficients on entry; overwritten with the
           reconstructed samples.
  N     -- current length. NOTE(review): presumably must be divisible by
           2^J, matching the pdwt() call that produced u; not checked here.
  J     -- number of levels still to undo; J<=0 (or N<=0) leaves u
           unchanged.
  h, g  -- filter coefficient arrays of length L, passed unchanged to
           ipcqfilter() (defined in pcqfilt.c).

  First reconstructs u[0..N/2-1] in place by recursion.  Then ipcqfilter()
  combines data from u into a zeroed scratch buffer of N floats, which is
  copied back into u.  Presumably ipcqfilter() reads the N/2 coarse values
  and N/2 detail values from u[0..N-1] -- confirm against pcqfilt.c.

  Aborts the process (exit(EXIT_FAILURE)) if scratch memory cannot be
  allocated.
*/
void
ipdwt(float u[], int N, int J, const float h[], const float g[], int L)
{
  float *temp;
  int i;

  if (J <= 0 || N <= 0)
    return;                     /* nothing left to reconstruct */

  /* Recurse before allocating so the deeper levels' buffers are already
     freed when this level's buffer is allocated. */
  ipdwt(u, N/2, J-1, h, g, L);

  temp = calloc((size_t)N, sizeof *temp);
  if (temp == NULL) {
    /* An assert() here would be compiled out under NDEBUG and the NULL
       pointer would then be dereferenced; fail loudly instead. */
    fprintf(stderr, "ipdwt: out of memory allocating %d floats\n", N);
    exit(EXIT_FAILURE);
  }

  ipcqfilter(temp, u, N/2, h, g, L);    /* one synthesis level into temp */
  for (i = 0; i < N; i++)
    u[i] = temp[i];

  free(temp);
}



/* Print one labelled row: "<label>[0:n-1]: v[0] v[1] ... v[n-1]\n". */
static void
print_row(const char *label, const float v[], int n)
{
  int k;

  printf("%s[%d:%d]:", label, 0, n-1);
  for (k = 0; k < n; k++)
    printf(" %f", v[k]);
  putchar('\n');
}

/*
  Round-trip demo.  Builds a 24-sample quadratic test signal and applies a
  J=3 level periodic wavelet transform with the 4-tap filters below.  It
  then inverts the transform and prints four rows: the signal, its
  transform, the reconstruction, and the reconstruction error.  The error
  is expected to be near zero if ipdwt() inverts pdwt().
*/
int
main(void)
{
  /* Enum constants (unlike const int) are constant expressions, so they
     can size the arrays below without creating VLAs. */
  enum { L = 4, J = 3, N = 24 };

  /* Daubechies 4 filter coefficients: */
  const float h[L] = {  0.48296291314453416, 0.83651630373780794,
                        0.22414386804201339, -0.12940952255126037};
  const float g[L] = { -0.12940952255126037, -0.22414386804201339,
                        0.83651630373780794, -0.48296291314453416};
  float u[N], save[N], diff[N];
  int n;

  /* Every one of the J levels halves the length, so N must divide by 2^J. */
  assert(N % (1 << J) == 0);

  for (n = 0; n < N; n++)
    save[n] = u[n] = (float)(n*n - 24*n + 203);   /* test signal */

  print_row("   u", u, N);

  pdwt(u, N, J, h, g, L);
  print_row("  Wu", u, N);

  ipdwt(u, N, J, h, g, L);
  print_row("W*Wu", u, N);

  for (n = 0; n < N; n++)
    diff[n] = u[n] - save[n];
  print_row("diff", diff, N);

  return 0;
}

