// FFT_new.cc (3.1 KB)

  1. #include <cassert>
  2. #include <cstdio>
  3. #include <cmath>
  // Minimal complex-number type used by the FFT below:
  // `a` is the real part and `b` the imaginary part.
  // Every constructor initializes both members. The implicit conversion from
  // double (e.g. `cpx x = 1;` or `z / 8`) therefore yields x + 0i, instead of
  // leaving the imaginary part indeterminate (reading it would be UB).
  struct cpx
  {
      cpx() : a(0), b(0) {}
      cpx(double aa) : a(aa), b(0) {}
      cpx(double aa, double bb) : a(aa), b(bb) {}

      double a;  // real part
      double b;  // imaginary part

      // Squared magnitude |z|^2 = a^2 + b^2.
      double modsq(void) const
      {
          return a * a + b * b;
      }

      // Complex conjugate: a - bi.
      cpx bar(void) const
      {
          return cpx(a, -b);
      }
  };
  20. cpx operator +(cpx a, cpx b)
  21. {
  22. return cpx(a.a + b.a, a.b + b.b);
  23. }
  24. cpx operator *(cpx a, cpx b)
  25. {
  26. return cpx(a.a * b.a - a.b * b.b, a.a * b.b + a.b * b.a);
  27. }
  28. cpx operator /(cpx a, cpx b)
  29. {
  30. cpx r = a * b.bar();
  31. return cpx(r.a / b.modsq(), r.b / b.modsq());
  32. }
  33. cpx EXP(double theta)
  34. {
  35. return cpx(cos(theta),sin(theta));
  36. }
  // 2*pi, computed as 4 * acos(0) because acos(0) = pi/2.
  const double two_pi = 4 * std::acos(0.0);
  38. // in: input array
  39. // out: output array
  40. // step: {SET TO 1} (used internally)
  41. // size: length of the input/output {MUST BE A POWER OF 2}
  42. // dir: either plus or minus one (direction of the FFT)
  43. // RESULT: out[k] = \sum_{j=0}^{size - 1} in[j] * exp(dir * 2pi * i * j * k / size)
  44. void FFT(cpx *in, cpx *out, int step, int size, int dir)
  45. {
  46. if(size < 1) return;
  47. if(size == 1)
  48. {
  49. out[0] = in[0];
  50. return;
  51. }
  52. FFT(in, out, step * 2, size / 2, dir);
  53. FFT(in + step, out + size / 2, step * 2, size / 2, dir);
  54. for(int i = 0 ; i < size / 2 ; i++)
  55. {
  56. cpx even = out[i];
  57. cpx odd = out[i + size / 2];
  58. out[i] = even + EXP(dir * two_pi * i / size) * odd;
  59. out[i + size / 2] = even + EXP(dir * two_pi * (i + size / 2) / size) * odd;
  60. }
  61. }
  // Usage:
  // f[0..N-1] and g[0..N-1] are (complex) numbers, with N a power of 2.
  // We want to compute their cyclic convolution h, defined by
  //   h[n] = sum of f[k] g[n-k]  (k = 0, ..., N-1),
  // where indices are taken modulo N: g[-1] = g[N-1], g[-2] = g[N-2], etc.
  // Let F[0..N-1] be FFT(f), and define G and H similarly.
  // The convolution theorem says H[n] = F[n] G[n] (element-wise product).
  // To compute h[] in O(N log N) time, do the following:
  // 1. Compute F and G (pass step = 1 and dir = 1).
  // 2. Get H by multiplying F and G element-wise.
  // 3. Get h by taking the inverse FFT of H (pass step = 1 and dir = -1)
  //    and *dividing by N*. DO NOT FORGET THIS SCALING FACTOR.
  74. int main(void)
  75. {
  76. printf("If rows come in identical pairs, then everything works.\n");
  77. cpx a[8] = {0, 1, cpx(1,3), cpx(0,5), 1, 0, 2, 0};
  78. cpx b[8] = {1, cpx(0,-2), cpx(0,1), 3, -1, -3, 1, -2};
  79. cpx A[8];
  80. cpx B[8];
  81. FFT(a, A, 1, 8, 1);
  82. FFT(b, B, 1, 8, 1);
  83. for(int i = 0 ; i < 8 ; i++)
  84. {
  85. printf("%7.2lf%7.2lf", A[i].a, A[i].b);
  86. }
  87. printf("\n");
  88. for(int i = 0 ; i < 8 ; i++)
  89. {
  90. cpx Ai(0,0);
  91. for(int j = 0 ; j < 8 ; j++)
  92. {
  93. Ai = Ai + a[j] * EXP(j * i * two_pi / 8);
  94. }
  95. printf("%7.2lf%7.2lf", Ai.a, Ai.b);
  96. }
  97. printf("\n");
  98. cpx AB[8];
  99. for(int i = 0 ; i < 8 ; i++)
  100. AB[i] = A[i] * B[i];
  101. cpx aconvb[8];
  102. FFT(AB, aconvb, 1, 8, -1);
  103. for(int i = 0 ; i < 8 ; i++)
  104. aconvb[i] = aconvb[i] / 8;
  105. for(int i = 0 ; i < 8 ; i++)
  106. {
  107. printf("%7.2lf%7.2lf", aconvb[i].a, aconvb[i].b);
  108. }
  109. printf("\n");
  110. for(int i = 0 ; i < 8 ; i++)
  111. {
  112. cpx aconvbi(0,0);
  113. for(int j = 0 ; j < 8 ; j++)
  114. {
  115. aconvbi = aconvbi + a[j] * b[(8 + i - j) % 8];
  116. }
  117. printf("%7.2lf%7.2lf", aconvbi.a, aconvbi.b);
  118. }
  119. printf("\n");
  120. return 0;
  121. }