"""Tests that the results generated by Brian are correct.

This is done by testing against algebraically solvable test cases,
or against known correct behaviour.
"""

import unittest

from brian import *
from brian.stdunits import *
from brian.globalprefs import get_global_preference
from brian.log import *

class TestSequenceFunctions(unittest.TestCase):
    """Checks Brian simulation results against analytic or known values."""

    def setUp(self):
        """Per-test fixture hook; nothing needs preparing at present.

        unittest only invokes a method spelled ``setUp``.
        """
        pass

    # Old spelling kept so any existing caller of ``set_up`` keeps working.
    set_up = setUp

    def testfromtutorial1c(self):
        '''Tests a behaviour from Tutorial 1c

        Solving the differential equation gives:

            V = El + (Vr-El) exp (-t/tau)

        Setting V=Vt at time t gives:

            t = tau log( (Vr-El) / (Vt-El) )

        If the simulator runs for time T, and fires a spike immediately
        at the beginning of the run it will then generate n spikes,
        where ([x] denotes the integer part of x):

            n = [T/t] + 1

        If you have m neurons all doing the same thing, you get nm
        spikes. This calculation with the parameters used below
        (T = 1 s, m = 40) gives:

            t = 48.0 ms
            n = 21
            nm = 840

        The test asserts that exactly this many spikes are recorded.
        '''
        tau = 20*msecond        # membrane time constant
        Vt  =-50*mvolt          # spike threshold
        Vr  =-60*mvolt          # reset value
        El  =-49*mvolt          # resting potential (above threshold Vt)
        dV = 'dV/dt = -(V-El)/tau : volt # membrane potential'
        # NOTE(review): this test passes `equation=` while
        # testexponentialcurrent passes `equations=`; confirm that Model
        # accepts both spellings.
        model = Model(equation=dV,threshold=Vt,reset=Vr)
        G = NeuronGroup(N=40,model=model)
        # Start above threshold (El > Vt) so every neuron fires at the
        # beginning of the run, as the derivation above assumes.
        G.V = El
        M = SpikeMonitor(G)
        run(1*second)
        self.assertEqual(M.nspikes, 840)

    def testexponentialcurrent(self):
        '''Tests whether an exponential current works as predicted

        From Tutorial 2b.

        The scheme implemented is the following system of differential
        equations:

            | taum dV/dt = -V + ge - gi
            | taue dge/dt = -ge
            | taui dgi/dt = -gi

        An excitatory neuron connects to state ge, and an inhibitory neuron connects
        to state gi. When an excitatory spike arrives, ge instantaneously increases,
        then decays exponentially. Consequently, V rises continuously (not
        instantaneously) and then falls. Solving these equations, if V(0)=0,
        ge(0)=g0 corresponding to an excitatory spike arriving at time 0, and
        gi(0)=0 then:

            | gi = 0
            | ge = g0 exp(-t/taue)
            | V = (exp(-t/taum) - exp(-t/taue)) taue g0 / (taum-taue)
        '''
        taum = 20*ms
        taue =  1*ms
        taui = 10*ms
        Vt   = 10*mV
        Vr   =  0*mV

        model = Model(equations='''
                       dV/dt = (-V+ge-gi)/taum : volt
                       dge/dt = -ge/taue : volt
                       dgi/dt = -gi/taui : volt
                       ''',
                       threshold=Vt,
                       reset=Vr)

        # Only source neuron 0 fires (at t=0). Source neuron 1, the one
        # wired to gi, never spikes, so gi stays 0 as the derivation assumes.
        spiketimes = [(0,0*ms)]

        G1 = SpikeGeneratorGroup(2,spiketimes)
        G2 = NeuronGroup(N=1,model=model)
        G2.V = Vr

        C1 = Connection(G1,G2,'ge')
        C2 = Connection(G1,G2,'gi')

        C1[0,0] = 3*mV   # g0 in the formula above
        C2[1,0] = 3*mV

        Mv = StateMonitor(G2,'V',record=True)

        run(100*ms)

        t = Mv.times
        Vpredicted = (exp(-t/taum) - exp(-t/taue))*taue*(3*mV) / (taum-taue)

        Vdiff = abs(Vpredicted - Mv[0])
        max_vdiff = max(Vdiff)

        self.assertTrue(max_vdiff < 0.00001*volt,
                        'max |V - Vpredicted| = %s' % (max_vdiff,))

    def testepsp(self):
        """Tests whether an alpha function EPSP works algebraically.

        The expected behaviour of the network below is that it should solve the
        following differential equation:

        taum   dV/dt = -V + x
        taupsp dx/dt = -x + y
        taupsp dy/dt = -y
                V(0) = 0 volt
                x(0) = 0 volt
                y(0) = y0 volt

        This gives the following analytical solution for V (computed with Mathematica):

        V(t) = (E^(-(t/taum) - t/taupsp)*(-(E^(t/taum)*t*taum) + 
           E^(t/taum)*t*taupsp - E^(t/taum)*taum*taupsp + 
           E^(t/taupsp)*taum*taupsp)*y0)/(taum - taupsp)^2

        This doesn't have an analytical solution for the maximum value of V, but the
        following numerical value was computed with the analytic formula:

                Vmax = 0.136889 mvolt  (accurate to that many sig figs)
        at time    t = 1.69735 ms (accurate to +/- 0.00001ms)

        The Brian network consists of two neurons. One is governed by the
        differential equations given above; the other fires a single spike
        at time t=0 and is connected to the first.
        """
        # NOTE(review): `clock` is not passed to the groups below. Presumably
        # Brian picks it up from this local namespace, so it must stay a
        # local created before the groups. Confirm before refactoring.
        clock = Clock(dt=0.1*ms)
        expected_vmax = 0.136889*mvolt
        expected_vmaxtime = 1.69735*msecond
        desired_vmaxaccuracy = 0.001*mvolt
        # Peak time can only be resolved to one time step.
        desired_vmaxtimeaccuracy = max(clock.dt,0.00001*ms)
        taum = 10*ms
        taupsp = 0.325*ms
        y0 = 4.86 * mV
        P = NeuronGroup(N=1, model='''
                      dV/dt = (-V+x)*(1./taum) : volt
                      dx/dt = (-x+y)*(1./taupsp) : volt
                      dy/dt = -y*(1./taupsp) : volt
                      ''',
                      threshold=100*mV,reset=0*mV)  # threshold never reached
        Pinit = SpikeGeneratorGroup(1,[(0,0*ms)])
        C = Connection(Pinit,P,'y')
        C.connect_full(Pinit,P,y0)
        M = StateMonitor(P,'V',record=0)
        run(10*ms)
        V = M[0]
        # Locate the first sample holding the maximum of V (strict '>').
        Vmax = 0*volt
        Vi = 0
        for i, v in enumerate(V):
            if v > Vmax:
                Vmax = v
                Vi = i
        Vmaxtime = M.times[Vi]
        self.assertTrue(abs(Vmax-expected_vmax) < desired_vmaxaccuracy,
                        'Vmax = %s, expected %s' % (Vmax, expected_vmax))
        self.assertTrue(abs(Vmaxtime-expected_vmaxtime) < desired_vmaxtimeaccuracy,
                        'Vmax time = %s, expected %s'
                        % (Vmaxtime, expected_vmaxtime))

def run_test():
    log_level_error()
    import inspect, brian, sys
    print '****************************************'
    print 'Running the behaviour verification tests'
    print '****************************************'
    print
    print 'Running from directory:', inspect.getsourcefile(brian)
    suite = unittest.TestLoader().loadTestsFromTestCase(TestSequenceFunctions)
    return unittest.TextTestRunner(stream=sys.stdout,verbosity=2).run(suite).wasSuccessful()    
    
if __name__ == "__main__":
    import sys
    # Propagate the outcome as the process exit status so shell scripts and
    # CI can detect failures. run_test() returns True only if all tests pass.
    sys.exit(0 if run_test() else 1)