#! /usr/bin/env python

def importWave(filename):
    '''
    Return the framerate and per-channel audio samples read from filename.

    importWave(filename) -> (samp_freq, channels)

    Input:
        filename - a string containing the name of a PCM wave file to parse
    Output:
        samp_freq - the sampling frequency of the wave file
        channels  - a list with one entry per channel; each entry is a list
                    of floats scaled to the range [-1, 1)
    Errors:
        ValueError('WaveIO Sample Width Error: ...') - if filename doesn't
            have 8, 16 or 32 bit samples
        ValueError('WaveIO Channel Error: ...')      - if the header reports
            fewer than one channel
        wave or struct module errors (i.e. if filename is compressed)
    '''

    import wave
    import struct

    # Read everything we need up front; always release the file handle,
    # including when validation or reading fails.
    f = wave.open(filename, 'rb')
    try:
        samp_width = f.getsampwidth()
        if samp_width not in (1, 2, 4):
            raise ValueError('WaveIO Sample Width Error: '
                             'Sample width of %s unparseable' % samp_width)
        samp_channels = f.getnchannels()
        if samp_channels < 1:
            raise ValueError('WaveIO Channel Error: '
                             '%s channels unparseable' % samp_channels)
        samp_freq = f.getframerate()
        frames = f.readframes(f.getnframes())
    finally:
        f.close()

    conv_code = packing_code(samp_width)

    # readframes may return fewer frames than the header advertises (e.g. a
    # truncated file), so size the unpack from the bytes actually read and
    # drop any trailing partial frame.
    frame_size = samp_width * samp_channels
    n_samples = (len(frames) // frame_size) * samp_channels
    # WAV PCM data is little-endian; '<' also forces standard struct sizes,
    # so 'i' is always 4 bytes regardless of platform.
    samples = struct.unpack('<%d%s' % (n_samples, conv_code),
                            frames[:n_samples * samp_width])

    scale = 1.0 / 2 ** (8 * samp_width - 1)
    # 8-bit PCM is unsigned with silence at 128; wider widths are signed.
    offset = 128 if samp_width == 1 else 0
    float_samples = [scale * (s - offset) for s in samples]

    # Samples are interleaved frame by frame: ch0, ch1, ..., ch0, ch1, ...
    channels = [float_samples[n::samp_channels]
                for n in range(samp_channels)]

    return (samp_freq, channels)

def exportWave(filename, samp_rate, signals, samp_width=2, scale=True):
    '''
    Write the channels in signals to filename as a PCM wave file and return
    True on success.

    Inputs:
        filename   - name of wave file to be written
        samp_rate  - the sampling rate; must be a positive whole number
        signals    - audio signal to be written, as a sequence of channels,
                     each a sequence of numbers; all channels must have the
                     same length
        samp_width - the sample size in bytes (1, 2 or 4), defaults to two
        scale      - passed to scale_signals: if true, rescale so the
                     largest-magnitude sample uses the full range; otherwise
                     samples are multiplied by full scale as-is
    Outputs:
        True on success
    Raises:
        Possibly errors from the wave or struct modules.
        ValueError('WaveIO Signal Error: ...')       - if no channels are
            passed or the channels differ in length
        ValueError('WaveIO Sample Width Error: ...') - if samp_width not in
            [1,2,4]
        ValueError('WaveIO Sample Freq Error: ...')  - if samp_rate <= 0 or
            not a whole number
    '''

    import wave
    import struct

    samp_channels = len(signals)
    if samp_channels == 0:
        raise ValueError('WaveIO Signal Error: '
                         'no channels passed to waveio.exportWave')

    samp_length = len(signals[0])
    if any(len(s) != samp_length for s in signals):
        raise ValueError('WaveIO Signal Error: '
                         'different sample lengths passed to '
                         'waveio.exportWave')

    if samp_width not in (1, 2, 4):
        raise ValueError('WaveIO Sample Width Error: '
                         'Sample width of %s not valid - need 1, 2, or 4'
                         % samp_width)

    if (samp_rate <= 0) or (samp_rate != int(samp_rate)):
        raise ValueError('WaveIO Sample Freq Error: '
                         'Sample frequency of %s invalid' % samp_rate)

    scaled_signals = scale_signals(signals, samp_width, scale)

    # Interleave the channels sample by sample: ch0, ch1, ..., ch0, ch1, ...
    spl_sig = [0] * (samp_channels * samp_length)
    for i, s in enumerate(scaled_signals):
        spl_sig[i::samp_channels] = s

    # WAV PCM data is little-endian; '<' also forces standard struct sizes.
    # Pack before opening the output so a packing error leaves no partial file.
    conv_code = packing_code(samp_width)
    packed = struct.pack('<%d%s' % (len(spl_sig), conv_code), *spl_sig)

    f = wave.open(filename, 'wb')
    try:
        f.setnchannels(samp_channels)
        f.setsampwidth(samp_width)
        f.setframerate(int(samp_rate))
        f.writeframes(packed)
    finally:
        # Closing finalises the header and releases the file handle.
        f.close()

    return True

def packing_code(samp_width):
    """
    Return the struct format character used to [un]pack one sample of the
    given width.

    Inputs:
        samp_width - number of bytes per sample
    Outputs:
        the conversion code used with struct.[un]pack: 'B' (unsigned 8 bit),
        'h' (signed 16 bit) or 'i' (signed 32 bit)
    Raises:
        ValueError('WaveIO Packing Error: ...') - if samp_width not in 1, 2, 4
    """
    # 8-bit wave samples are unsigned; 16- and 32-bit samples are signed.
    codes = {1: 'B', 2: 'h', 4: 'i'}
    code = codes.get(samp_width)
    if code is None:
        raise ValueError('WaveIO Packing Error: '
                         'Not able to parse %s bytes' % samp_width)
    return code

def scale_signals(signals, samp_width, scale):
    """
    Convert numeric signals to integer samples that fit a samp_width-byte
    PCM wave file.

    Input:
        signals    - sequence of channels, each a sequence of numbers
        samp_width - bytes per sample (1, 2 or 4)
        scale      - if true, rescale so the largest-magnitude sample spans
                     the full integer range; if false, samples are multiplied
                     by full scale as-is and so should lie in [-1, 1)
    Output:
        a list with one list of ints per channel.  Values are truncated
        toward zero and clipped to
        [-2**(8*samp_width-1), 2**(8*samp_width-1) - 1]; for samp_width == 1
        they are then offset by 128 into [0, 255], because 8-bit wave
        samples are unsigned.
    Raises:
        errors from int()/arithmetic on non-numeric or NaN samples
    """
    scaler = 2 ** (8 * samp_width - 1.)
    eps = 1. / scaler
    lo, hi = -scaler, scaler - 1.

    maxval = 1.
    if scale:
        # Ignore empty channels so max()/min() never see an empty sequence.
        nonempty = [s for s in signals if len(s)]
        if nonempty:
            posmax = max(max(s) for s in nonempty)
            negmax = -min(min(s) for s in nonempty)
            if posmax >= negmax:
                # The positive range is one step shorter (hi = scaler - 1),
                # so shrink by (1 - eps) to land the positive peak on hi.
                maxval = posmax / (1. - eps)
            else:
                # The negative peak maps onto lo = -scaler.
                maxval = negmax
        if maxval == 0:
            # All-zero input: any gain works; avoid dividing by zero.
            maxval = 1.

    sc = scaler / maxval
    # 8-bit wave data is unsigned with silence at 128; wider widths are signed.
    offset = 128 if samp_width == 1 else 0

    # Clipping only changes values that would otherwise overflow the sample
    # width; in-range values truncate exactly as int(sc * x) would.
    return [[int(min(max(sc * x, lo), hi)) + offset for x in s]
            for s in signals]
