#=======================================================================
# Direct and inverse FFT examples (including frequency analysis,
# spectrogram generation and "brutal" band pass filtering)
#
# Cesare Brizio, 3 January 2024
#
# Based on several examples, no creative merit on my part except
# reconciling a few variable / array names among the examples.
#
#=======================================================================
import numpy as np
from matplotlib import pyplot as plt
from scipy.io import wavfile
from scipy.fft import fft, fftfreq
from scipy.fft import rfft, rfftfreq
from scipy.fft import irfft
from scipy import signal


#Sampling parameters shared by every step of this script
SAMPLE_RATE = 44100  #Samples per second (also the rate passed to wavfile.write)
DURATION = 5  #Seconds (length of every generated tone)
FREQUENCY_TEST = 2 #Hertz (frequency of the first demonstration sine wave)

def generate_sine_wave(freq, sample_rate, duration):
    """Return the time axis and the samples of a unit-amplitude sine wave.

    Parameters
    ----------
    freq : frequency of the tone, in hertz.
    sample_rate : number of samples per second.
    duration : length of the signal in seconds (may be fractional).

    Returns
    -------
    (x, y) : two arrays of equal length. ``x`` holds the sample instants
        in seconds, starting at 0 and excluding ``duration`` itself;
        ``y`` holds ``sin(2 * pi * freq * x)``.
    """
    #np.linspace only accepts an integer sample count, while
    #sample_rate * duration is a float as soon as either argument is
    #fractional (e.g. a 0.5 second tone), so convert it explicitly.
    num_samples = int(round(sample_rate * duration))
    x = np.linspace(0, duration, num_samples, endpoint=False)
    frequencies = x * freq
    #2pi because np.sin takes radians
    y = np.sin((2 * np.pi) * frequencies)
    return x, y


#======================================================
#======================================================
#
# GENERATE AND PLOT A SINE WAVE
#
#======================================================
#======================================================

#Build the demonstration tone (FREQUENCY_TEST Hz for DURATION seconds)
#and show it in the time domain
x, y = generate_sine_wave(FREQUENCY_TEST, SAMPLE_RATE, DURATION)

#Title assembled separately to keep the plotting calls short
sine_title = (
    "2 Hz sine Wave generated by generate_sine_wave()\nSample rate = "
    + str(SAMPLE_RATE)
    + " Hz, Duration = "
    + str(DURATION)
    + " sec\nClose window for next step"
)

plt.rcParams['figure.figsize'] = [12, 7]
plt.plot(x, y)
plt.title(sine_title, fontsize=14, fontweight="bold", color="green")
plt.xlabel('sec')
plt.ylabel('SIN value')
#Blocks until the window is closed
plt.show()


#======================================================
#======================================================
#
# GENERATE, NORMALIZE, PLOT AND SAVE AS A WAVEFILE  
# THE SUM OF TWO SINE WAVES (WILL BE BAND-STOP-FILTERED
# IN SUBSEQUENT STEPS)
#
#======================================================
#======================================================
FREQUENCY_GOOD = 400  #Hertz, the tone we want to keep
FREQUENCY_NOISE = 4000  #Hertz, the tone that later steps filter out
_, nice_tone = generate_sine_wave(FREQUENCY_GOOD, SAMPLE_RATE, DURATION)
_, noise_tone = generate_sine_wave(FREQUENCY_NOISE, SAMPLE_RATE, DURATION)
#The noise is mixed in at 30% of the amplitude of the good tone
noise_tone = noise_tone * 0.3

mixed_tone = nice_tone + noise_tone

#Scale to the 16-bit integer range. The peak ABSOLUTE value is the
#reference: dividing by max() alone would exceed the int16 range
#whenever the negative peak is larger than the positive one.
normalized_tone = np.int16((mixed_tone / np.abs(mixed_tone).max()) * 32767)

plt.rcParams['figure.figsize'] = [12, 7]
#Only the first LIST_SLICE samples are plotted, otherwise the
#individual cycles would not be distinguishable
LIST_SLICE = 1000
DURATION_LIST_SLICE = round((1/SAMPLE_RATE)*LIST_SLICE,3)
FREQ_GOOD_CYCLES_PER_SLICE = round(DURATION_LIST_SLICE*FREQUENCY_GOOD,3)
plt.plot(normalized_tone[:LIST_SLICE])
# displaying the title 
plt.title("Normalized tone from the sum of 2 sine waves (will be saved as mysinewave.wav)\n\u00ABGood\u00BB = "+str(FREQUENCY_GOOD)+" Hz, \u00ABNoise\u00BB = "+str(FREQUENCY_NOISE)+" Hz, Duration = first "+str(LIST_SLICE)+" entries at a sample rate of "+str(SAMPLE_RATE)+" Hz\nClose window for next step", 
          fontsize=14,
          fontweight="bold", 
          color="green") 
plt.ylabel('Samples')
#The entry count in the label follows LIST_SLICE instead of a hard-coded 1000
plt.xlabel('Entries at a sample rate of '+str(SAMPLE_RATE)+' Hz - '+str(LIST_SLICE)+' entries = '+str(DURATION_LIST_SLICE)+' sec - '+str(LIST_SLICE)+' entries = '+str(FREQ_GOOD_CYCLES_PER_SLICE)+' cycles at '+str(FREQUENCY_GOOD)+' Hz')
plt.show()

#Remember SAMPLE_RATE = 44100 Hz is our playback rate
#Raw string: "\m" is not a valid escape sequence, so the plain literal
#triggers a SyntaxWarning on recent Python versions. The path is unchanged.
wavfile.write(r"C:\mysinewave.wav", SAMPLE_RATE, normalized_tone)


#======================================================
#======================================================
#
# PERFORM AND DISPLAY DIRECT FFT (FROM TIME DOMAIN TO 
# FREQUENCY DOMAIN) OF THE NORMALIZED COMBINATION
# OF TWO SINE WAVES (WILL BE BAND-STOP-FILTERED
# IN SUBSEQUENT STEPS). THE GENERAL (COMPLEX-INPUT) FFT
# IS USED, SO THE OUTPUT INCLUDES NEGATIVE FREQUENCIES
#
#======================================================
#======================================================

#Total sample count, derived from sample rate and duration
N = SAMPLE_RATE * DURATION

#Bin frequencies (negative and positive) and the matching spectrum values
xf = fftfreq(N, 1 / SAMPLE_RATE)
yf = fft(normalized_tone)

#Title assembled separately to keep the plotting calls short
full_fft_title = (
    "Direct FFT of the normalized tone\nFull-Range ( -"
    + str(SAMPLE_RATE/2)
    + " Hz to +"
    + str(SAMPLE_RATE/2)
    + " Hz) Pressure-Frequency Analysis\nClose window for next step"
)

plt.rcParams['figure.figsize'] = [12, 7]
#The spectrum is complex: plot its magnitude
plt.plot(xf, np.abs(yf))
plt.title(full_fft_title, fontsize=14, fontweight="bold", color="green")
plt.xlabel('Frequency')
plt.ylabel('FFT value (sound pressure)')
plt.show()



#======================================================
#======================================================
#
# PERFORM AND DISPLAY DIRECT FFT (FROM TIME DOMAIN TO 
# FREQUENCY DOMAIN) OF THE NORMALIZED COMBINATION
# OF TWO SINE WAVES (WILL BE BAND-STOP-FILTERED
# IN SUBSEQUENT STEPS). THE REAL-INPUT FFT (RFFT) IS
# USED, WHICH OMITS THE NEGATIVE FREQUENCIES
#
#======================================================
#======================================================

#Real-input variants of the transform: note the extra 'r' in the names.
#xf and yf are reused by the band-stop filtering step further down.
xf = rfftfreq(N, 1 / SAMPLE_RATE)
yf = rfft(normalized_tone)

#Title assembled separately to keep the plotting calls short
real_fft_title = (
    "Direct FFT of the normalized tone (REAL INPUT, does not compute the negative frequency terms)\nRange (0 Hz to +"
    + str(SAMPLE_RATE/2)
    + " Hz) Pressure-Frequency Analysis\nClose window for next step"
)

plt.rcParams['figure.figsize'] = [12, 7]
#The spectrum is complex: plot its magnitude
plt.plot(xf, np.abs(yf))
plt.title(real_fft_title, fontsize=14, fontweight="bold", color="green")
plt.xlabel('Frequency')
plt.ylabel('FFT value (sound pressure)')
plt.show()


#======================================================
#======================================================
#
# READ THE WAVEFILE AND, USING THE WELCH() FUNCTION  
# FROM THE SIGNAL PACKAGE, PLOT THE POWER SPECTRUM
# IN dBFS.
# Welch’s method computes an estimate of the power 
# spectral density by dividing the data into 
# overlapping segments, computing a modified 
# periodogram for each segment and averaging the 
# periodograms.
#
#======================================================
#======================================================

#Number of samples in each Welch segment (also used as the FFT length)
segment_size = 512

#Raw string: "\m" is not a valid escape sequence, so the plain literal
#triggers a SyntaxWarning on recent Python versions. The path is unchanged.
fs, x = wavfile.read(r'C:\mysinewave.wav')
x = x / 32768.0  # scale signal to [-1.0 .. 1.0]

#50% overlap between consecutive segments. Integer division, because
#noverlap is a number of samples.
noverlap = segment_size // 2
f, Pxx = signal.welch(x,                        # signal
                      fs=fs,                    # sample rate
                      nperseg=segment_size,     # segment size
                      window='hann',            # window type to use e.g.hamming or hann
                      nfft=segment_size,        # num. of samples in FFT
                      detrend=False,            # no detrending: the DC part is kept
                      scaling='spectrum',       # return power spectrum [V^2]
                      noverlap=noverlap)        # overlap between segments (in samples)

# set 0 dB to energy of sine wave with maximum amplitude:
# the RMS of a full-scale sine is 1/sqrt(2), its square is 0.5
ref = (1 / np.sqrt(2)) ** 2
p = 10 * np.log10(Pxx/ref)

plt.rcParams['figure.figsize'] = [12, 7]

fill_to = -150 * (np.ones_like(p))  # anything below -150dB is irrelevant
plt.fill_between(f, p, fill_to )
#The x axis starts at the third frequency bin, skipping DC and the first bin
plt.xlim([f[2], f[-1]])
plt.ylim([-150, 6])
# plt.xscale('log')   # uncomment if you want log scale on x-axis
plt.grid(True)
# displaying the title (it now describes the Welch power spectrum; the
# previous text was copied from the direct FFT step)
plt.title("Power spectrum of the normalized tone (Welch's method, "+str(segment_size)+"-sample Hann segments)\nRange (0 Hz to +"+str(fs/2)+" Hz) Power-Frequency Analysis in dBFS\nClose window for next step", 
          fontsize=14,
          fontweight="bold", 
          color="green") 
plt.xlabel('Frequency, Hz')
plt.ylabel('Power spectrum, dBFS')
plt.show()



#======================================================
#======================================================
#
# READ THE WAVEFILE AND, USING THE SPECTROGRAM() 
# FUNCTION FROM THE SIGNAL PACKAGE, PLOT THE 
# TIME/FREQUENCY SPECTROGRAM.
#
#======================================================
#======================================================

#The figure size must be configured BEFORE the first plotting call:
#pcolormesh creates the figure, and later rcParams changes do not resize it
plt.rcParams['figure.figsize'] = [12, 7]

#x was read from the wavefile, so use the rate read with it (fs)
f, t, Sxx = signal.spectrogram(x, fs)
plt.pcolormesh(t, f, Sxx, shading='nearest') # shading may be flat, nearest, gouraud or auto
# displaying the title (a spectrogram is a time/frequency representation)
plt.title("Spectrogram of the normalized tone (REAL INPUT, does not compute the negative frequency terms)\nRange (0 Hz to +"+str(fs/2)+" Hz) Time-Frequency Analysis\nClose window for next step", 
          fontsize=14,
          fontweight="bold", 
          color="green") 
plt.ylabel('Frequency [Hz]')
plt.xlabel('Time [sec]')
plt.show()


#======================================================
#======================================================
#
# BAND-STOP FILTERING (BRUTAL, NO WINDOWING FUNCTION)
#
# Above, scipy.fft.rfftfreq() was used to return the
# Discrete Fourier Transform sample frequencies.
# The xf float array, returned by scipy.fft.rfftfreq()
# contains the frequency bin centers in cycles per unit 
# of the sample spacing (with zero at the start). For 
# instance, if the sample spacing is in seconds, then 
# the frequency unit is cycles/second.
#
# The yf complex array was returned by scipy.fft.rfft(),
# a function that computes the 1-D n-point discrete 
# Fourier Transform (DFT) for real input.
#
# By setting to zero the yf values corresponding to 
# a given xf bin, we are brutally silencing the range 
# of frequencies subsumed by that bin.
#
# A more sophisticated approach to filtering is based
# on windowing functions, such as the one exemplified
# above. A brutal surrogate of the windowing function
# is to also set to zero a few bins adjacent to the 
# target bin.
#
#======================================================
#======================================================

#xf is evenly spaced starting at 0 Hz, so the distance between two
#consecutive entries is the width of one frequency bin in Hz.
#Dividing len(xf) by the maximum frequency would count the 0 Hz bin as
#well and slightly overestimate the density.
bin_width = xf[1] - xf[0]
points_per_freq = 1 / bin_width

#Our target (noise) frequency is FREQUENCY_NOISE (4000 Hz).
#Round to the nearest bin instead of truncating, so that a product such
#as 19999.999 does not select the wrong bin.
target_idx = int(round(FREQUENCY_NOISE * points_per_freq))
#Also a small range of adjacent bins is set to zero. The lower bound is
#clamped because a negative slice start would wrap around to the end.
yf[max(target_idx - 1, 0) : target_idx + 2] = 0

plt.rcParams['figure.figsize'] = [12, 7]
plt.plot(xf, np.abs(yf))
# displaying the title 
plt.title("Effects of the \u00ABbrutal\u00BB band stop filtering of the normalized tone \nRange (0 Hz to +"+str(SAMPLE_RATE/2)+" Hz) Pressure-Frequency Analysis\nClose window for next step", 
          fontsize=14,
          fontweight="bold", 
          color="green") 
plt.ylabel('FFT value (sound pressure)')
plt.xlabel('Frequency')
plt.show()


#======================================================
#======================================================
#
# RESTORE THE SINE WAVE (NOW NOT INCLUDING NOISE) BY
# PERFORMING AN INVERSE FFT (REAL VALUES ONLY), DISPLAY
# THE CLEAN WAVE AND SAVE IT AS A WAVEFILE.
#
#======================================================
#======================================================

#Back to the time domain. The output length is passed explicitly:
#without it irfft assumes an even length and would drop one sample if
#the original signal had an odd number of samples.
new_sig = irfft(yf, n=len(normalized_tone))

plt.rcParams['figure.figsize'] = [12, 7]
plt.plot(new_sig[:LIST_SLICE])
# displaying the title 
plt.title("Normalized tone after the band-stop filtering of the noise wave (will be saved as clean.wav)\nDuration = first "+str(LIST_SLICE)+" entries at a sample rate of "+str(SAMPLE_RATE)+" Hz\nClose window to end program", 
          fontsize=14,
          fontweight="bold", 
          color="green") 
plt.ylabel('Samples')
#The entry count in the label follows LIST_SLICE instead of a hard-coded 1000
plt.xlabel('Entries at a sample rate of '+str(SAMPLE_RATE)+' Hz - '+str(LIST_SLICE)+' entries = '+str(DURATION_LIST_SLICE)+' sec - '+str(LIST_SLICE)+' entries = '+str(FREQ_GOOD_CYCLES_PER_SLICE)+' cycles at '+str(FREQUENCY_GOOD)+' Hz')
plt.show()

#Scale to the 16-bit integer range. The peak ABSOLUTE value is the
#reference: scaling by max() alone would exceed the int16 range
#whenever the negative peak is larger than the positive one.
norm_new_sig = np.int16(new_sig * (32767 / np.abs(new_sig).max()))


#Raw string: "\c" is not a valid escape sequence, so the plain literal
#triggers a SyntaxWarning on recent Python versions. The path is unchanged.
wavfile.write(r"C:\clean.wav", SAMPLE_RATE, norm_new_sig)