Python Forum

Full Version: odeint to solve Schrodinger equation
You're currently viewing a stripped down version of our content. View the full version with proper formatting.
Pages: 1 2
Hi, Im trying to solve the Schrodinger equation. I am basing myself on this site but in altering the code odeint is giving me the wrong results. the functions find_all_zeroes(x,y) and find_analytic_energies(en) are supposed to give me the the same results but they are vastly different. This is the altered code I am using for the part in question. Can someone tell me what I am doing wrong?

from pylab import *
from scipy.integrate import odeint
from scipy.optimize import brentq

a=1
B=4
L= B+a
Vmax= 50
Vpot = False
def V(x):
    '''
    #Potential function in the finite square well.
    '''
    if -a <=x <=a:
        val = Vo
    elif x<=-a-B:
        val = Vmax
    elif x>=L:
        val = Vmax
    else:
        val = 0
    if Vpot==True:
          if -a-B-(10/N) < x <= L+(1/N):
             Ypotential.append(val)
             Xpotential.append(x)
    return val
 
def SE(psi, x):
    """
    Returns derivatives for the 1D schrodinger eq.
    Requires global value E to be set somewhere. State0 is first derivative of the
    wave function psi, and state1 is its second derivative.
    """
    state0 = psi[1]
    state1 = 2.0*(V(x) - E)*psi[0]
    return array([state0, state1])
 
def Wave_function(energy):
    """
    Calculates wave function psi for the given value
    of energy E and returns value at point b
    """
    global psi
    global E
    E = energy
    psi = odeint(SE, psi0, x)
    return psi[-1,0]
 
def find_all_zeroes(x,y):
    """
    Gives all zeroes in y = Psi(x)
    """
    all_zeroes = []
    s = sign(y)
    for i in range(len(y)-1):
        if s[i]+s[i+1] == 0:
            zero = brentq(Wave_function, x[i], x[i])
            all_zeroes.append(zero)
    return all_zeroes
 
def find_analytic_energies(en):
    """
    Calculates Energy values for the finite square well using analytical
    model (Griffiths, Introduction to Quantum Mechanics, 1st edition, page 62.)
    """
    z = sqrt(2*en)
    z0 = sqrt(2*Vo)
    z_zeroes = []
    f_sym = lambda z: tan(z)-sqrt((z0/z)**2-1)      # Formula 2.138, symmetrical case
    f_asym = lambda z: -1/tan(z)-sqrt((z0/z)**2-1)  # Formula 2.138, antisymmetrical case
 
    # first find the zeroes for the symmetrical case
    s = sign(f_sym(z))
    for i in range(len(s)-1):   # find zeroes of this crazy function
       if s[i]+s[i+1] == 0:
           zero = brentq(f_sym, z[i], z[i+1])
           z_zeroes.append(zero)
    print ("Energies from the analyitical model are: ")
    print ("Symmetrical case)")
    for i in range(0, len(z_zeroes),2):   # discard z=(2n-1)pi/2 solutions cause that's where tan(z) is discontinous
        print ("%.4f" %(z_zeroes[i]**2/2))
    # Now for the asymmetrical
    z_zeroes = []
    s = sign(f_asym(z))
    for i in range(len(s)-1):   # find zeroes of this crazy function
       if s[i]+s[i+1] == 0:
           zero = brentq(f_asym, z[i], z[i+1])
           z_zeroes.append(zero)
    print ("(Antisymmetrical case)")
    for i in range(0, len(z_zeroes),2):   # discard z=npi solutions cause that's where ctg(z) is discontinous
        print ("%.4f" %(z_zeroes[i]**2/2))
 
N = 1000                  # number of points to take
psi = np.zeros([N,2])     # Wave function values and its derivative (psi and psi')
psi0 = array([0,1])   # Wave function initial states
Vo = 50
E = 0.0                   # global variable Energy  needed for Sch.Eq, changed in function "Wave function"
b = L                     # point outside of well where we need to check if the function diverges
x = linspace(-B-a, L, N)    # x-axis
 
def main():
    # main program        
 
    en = linspace(0, Vo, 1000000)   # vector of energies where we look for the stable states
 
    psi_b = []      # vector of wave function at x = b for all of the energies in en
    for e1 in en:
        psi_b.append(Wave_function(e1))     # for each energy e1 find the the psi(x) at x = b
    E_zeroes = find_all_zeroes(en, psi_b)   # now find the energies where psi(b) = 0 
 
    # Print energies for the bound states
    print ("Energies for the bound states are: ")
    for E in E_zeroes:
        print ("%.2f" %E)
    # Print energies of each bound state from the analytical model
    find_analytic_energies(en)   
 
    # Plot wave function values at b vs energy vector
    figure()
    plot(en/Vo,psi_b)
    title('Values of the $\Psi(b)$ vs. Energy')
    xlabel('Energy, $E/V_0$')
    ylabel('$\Psi(x = b)$', rotation='horizontal')
    for E in E_zeroes:
        plot(E/Vo, [0], 'go')
        annotate("E = %.2f"%E, xy = (E/Vo, 0), xytext=(E/Vo, 30))
    grid()
 
    # Plot the wavefunctions for first 4 eigenstates
    figure(2)
    for E in E_zeroes[0:4]:
        Wave_function(E)
        plot(x, psi[:,0], label="E = %.2f"%E)
    legend(loc="upper right")
    title('Wave function')
    xlabel('x, $x/L$')
    ylabel('$\Psi(x)$', rotation='horizontal', fontsize = 15)
    grid()

    figure(3)
    pot =[]
    for i in x:
        pot.append(V(i))
    plot(x,pot)
    show()
if __name__ == "__main__":
    main()
That code needs some serious refactoring. I've been digging into it for a few... twenty minutes... -ish... and I haven't sussed out the deal with E. That global variable doesn't appear do actually do anything in the code. Every subsequent instance when a variable E is used, the variable is instantiated in a for loop, so the global isn't being used at all, it seems. Plus, it gets set as a global from a function - very bad practice.

Can you provide the expected results for the problematic functions when given known arguments?
the function find_analytic_energies(en) finds the correct values, the list E_zeroes (populated in line 109) should contain the same values, problem is that I need the function Wave function to work for other things down the line. If I try to use the values generated via ind_analytic_energies(en) with Wave function(energy) I do not get the expected results, ie the wave isn't at zero on both ends of V(x).
Okay, I'm going to work on cleaning up the code. There are some peculiarities that I will need explained by someone who understands the calculations at work. So, I will post my marked up version later with lots of line comments. If you could clarify those lines, I would appreciate it.
thanks, I'll try to answer everything, I sould be able to since I did the math by hand. I also asked here, though it concerns it's more concerned with the method used and not code debugging.
I uploaded the file to my GitHub. I added some line comments with questions about the code. Right now, I need clarification on what each function and variable truly is. The names the original author selected are not descriptive so I don't know if V() is for Vector() or Velocity() or any number of other concepts beginning with a "v".

For the first round of refactoring, I'm going to focus on lateral changes to correct some of the bad practices (such as setting globals from inside a function). That should not change anything operationally.
I think I answered all of the questions in the code
V() would work better as a closure or class method Needs a more descriptive name; what is V? Vector? Velocity? V is the potential function for the system, eg imagin ammonia molecule, it is shaped like a pyramid with three hydrogen at the base and nitrogen at the top. the potential Vo represents the area with the hydrogen atoms. the area where V(x)= 0 are the places that you'd probably find N. when Wave_function is called and all the zeros are found i have the energy with that energy I can plot the wave function onto a graph, but like the wave function is not being solved correctly because when I plot the wave using the correct energy that is found by find_analytic_energies, at x =-a-B the wave function is zero( which is correct) but at x = L the wave function is not zero, like this picture:
[Image: Infinite+Potential+Well+…+bottom+line.jpg]
that leads me to conclude that the problem is with the solution of the system of ordinary differential equations. I can't tell if I followed the advice given to me on scicomp
"If you really want to deal with an infinite potential well, then you should set b=L
and enforce the boundary condition ψ(b)=0. In this case it also makes sense
to start shooting at x=−b, with ψ(−b)=0and ψ′(b)nonzero." - LonelyProf

from pylab import *
from scipy.integrate import odeint
from scipy.optimize import brentq

a=1 # Never changes. Constant? yes
B=4 # Never changes. Constant? yes
L= B+a # Never changes. Constant? yes these tree are here more for me to easily change
Vmax= 50
Vpot = False


# V() would work better as a closure or class method
# Needs a more descriptive name; what is V? Vector? Velocity?
# V is the potential function for the sistem, eg imagin ammonia molecule,
# it is shaped like a pyramid
 with three hydrogen at the baseand nitrogen
# at the top. the potential Vo represents the area with the the hydrogen
# atoms. the area where V(x)= 0 are the places that you'd probably find N.
# when Wave_function is called and all the zeros are found i have the energy
# with that energy I can plot the wave function onto a graph, but like the
# wave function is not being solved correctly because when I plot  the wave
# using the correct energy that is found by find_analytic_energies, at x =-a-B
# the wave function is zero( whichis correct) but at x = L the wave function is not zero
# that leads me to conclude that the problem is with the solution of the system of ordinary
# differential equations. I can't tell if I followed the advice given to me on scicomp
# "If you really want to deal with an infinite potential well, then you should set  b=L
# and enforce the boundary condition ψ(b)=0. In this case it also makes sense
# to start shooting at x=−b, with ψ(−b)=0and ψ′(b)nonzero." - LonelyProf

def V(x):
    '''
    #Potential function in the finite square well.
    '''
    if -a <=x <=a:
        val = Vo
    elif x<=-a-B:
        val = Vmax
    elif x>=L:
        val = Vmax
    else:
        val = 0
    # This conditional can never be entered     #### this is here for parts of the code that come later on
    ##                                               I tried reducing things here to a min. for the problem to be clearer
    if Vpot==True: # never the case, Vpot does not change
          if -a-B-(10/N) < x <= L+(1/N):
             Ypotential.append(val) # sequence does not exist
             Xpotential.append(x) # sequence does not exist
    return val

def SE(psi, x):
    """
    Returns derivatives for the 1D schrodinger eq.
    Requires global value E to be set somewhere. State0 is first derivative of the
    wave function psi, and state1 is its second derivative.
    """
    state0 = psi[1]
    state1 = 2.0*(V(x) - E)*psi[0]
    return array([state0, state1])

def Wave_function(energy):
    """
    Calculates wave function psi for the given value
    of energy E and returns value at point b
    """
    global psi # Functions should not call global variables
    global E # Functions should not call global variables
    E = energy # Functions should not set global variables from within
    psi = odeint(SE, psi0, x) # Functions should not set global variables from within
    return psi[-1,0]

def find_all_zeroes(x,y):
    """
    Gives all zeroes in y = Psi(x)
    """
    all_zeroes = []
    s = sign(y)
    for i in range(len(y)-1):
        if s[i]+s[i+1] == 0:
            zero = brentq(Wave_function, x[i], x[i])
            all_zeroes.append(zero)
    return all_zeroes

def find_analytic_energies(en):
    """
    Calculates Energy values for the finite square well using analytical
    model (Griffiths, Introduction to Quantum Mechanics, 1st edition, page 62.)
    """
    z = sqrt(2*en)
    z0 = sqrt(2*Vo)
    z_zeroes = []
    f_sym = lambda z: tan(z)-sqrt((z0/z)**2-1)      # Formula 2.138, symmetrical case
    f_asym = lambda z: -1/tan(z)-sqrt((z0/z)**2-1)  # Formula 2.138, antisymmetrical case

    # first find the zeroes for the symmetrical case
    s = sign(f_sym(z))
    for i in range(len(s)-1):   # find zeroes of this crazy function
       if s[i]+s[i+1] == 0:
           zero = brentq(f_sym, z[i], z[i+1])
           z_zeroes.append(zero)
    print ("Energies from the analyitical model are: ")
    print ("Symmetrical case)")
    for i in range(0, len(z_zeroes),2):   # discard z=(2n-1)pi/2 solutions cause that's where tan(z) is discontinous
        print ("%.4f" %(z_zeroes[i]**2/2))
    # Now for the asymmetrical
    z_zeroes = []
    s = sign(f_asym(z))
    for i in range(len(s)-1):   # find zeroes of this crazy function
       if s[i]+s[i+1] == 0:
           zero = brentq(f_asym, z[i], z[i+1])
           z_zeroes.append(zero)
    print ("(Antisymmetrical case)")
    for i in range(0, len(z_zeroes),2):   # discard z=npi solutions cause that's where ctg(z) is discontinous
        print ("%.4f" %(z_zeroes[i]**2/2))

N = 1000                  # number of points to take
psi = np.zeros([N,2])     # Wave function values and its derivative (psi and psi')
psi0 = array([0,1])   # Wave function initial states
Vo = 50
E = 0.0                   # global variable Energy  needed for Sch.Eq, changed in function "Wave function"
b = L                     # point outside of well where we need to check if the function diverges
x = linspace(-B-a, L, N)    # x-axis

def main():
    # main program

    en = linspace(0, Vo, 1000000)   # vector of energies where we look for the stable states

    psi_b = []      # vector of wave function at x = b for all of the energies in en
    for e1 in en:
        psi_b.append(Wave_function(e1))     # for each energy e1 find the the psi(x) at x = b
    E_zeroes = find_all_zeroes(en, psi_b)   # now find the energies where psi(b) = 0

    # Print energies for the bound states
    print ("Energies for the bound states are: ")
    for E in E_zeroes:
        print ("%.2f" %E)
    # Print energies of each bound state from the analytical model
    find_analytic_energies(en)

    # Plot wave function values at b vs energy vector
    figure()
    plot(en/Vo,psi_b)
    title('Values of the $\Psi(b)$ vs. Energy')
    xlabel('Energy, $E/V_0$')
    ylabel('$\Psi(x = b)$', rotation='horizontal')
    for E in E_zeroes:
        plot(E/Vo, [0], 'go')
        annotate("E = %.2f"%E, xy = (E/Vo, 0), xytext=(E/Vo, 30))
    grid()

    # Plot the wavefunctions for first 4 eigenstates
    figure(2)
    for E in E_zeroes[0:4]:
        Wave_function(E)
        plot(x, psi[:,0], label="E = %.2f"%E)
    legend(loc="upper right")
    title('Wave function')
    xlabel('x, $x/L$')
    ylabel('$\Psi(x)$', rotation='horizontal', fontsize = 15)
    grid()

    figure(3)
    pot =[]
    for i in x:
        pot.append(V(i))
    plot(x,pot)
    show()
if __name__ == "__main__":
    main()
I'm currently trying to swap odeint out for scipy.integrate.solve_ivp because on odeint's page it says this:
Quote:Note
For new code, use scipy.integrate.solve_ivp to solve a differential equation.


I tried swapping, but I keep getting this error:
Error:
...Programs\Python\Python36-32\lib\site-packages\scipy\integrate\_ivp\rk.py", line 67, in rk_step K[0] = f ValueError: could not broadcast input array from shape (2,2) into shape (2)
this is what wavefunction looks like now:
from pylab import *
from scipy.integrate import solve_ivp
from scipy.optimize import brentq

a=1
B=4
L= B+a
Vmax= 50
Vpot = False

N = 1000                  # number of points to take
psi = np.zeros([N,2])     # Wave function values and its derivative (psi and psi')
psi0 = array([0,1])   # Wave function initial states
Vo = 50
E = 0.0                   # global variable Energy  needed for Sch.Eq, changed in function "Wave function"
b = L                     # point outside of well where we need to check if the function diverges
x = linspace(-B-a, L, N)    # x-axis
def V(x):
    '''
    #Potential function in the finite square well.
    '''
    if -a <=x <=a:
        val = Vo
    elif x<=-a-B:
        val = Vmax
    elif x>=L:
        val = Vmax
    else:
        val = 0
    if Vpot==True:
          if -a-B-(10/N) < x <= L+(1/N):
             Ypotential.append(val)
             Xpotential.append(x)
    return val
 
def SE(x, p):
    state0 = psi[1]
    state1 = 1.0*(V(x) - E)*psi[0]
    return array([state0, state1])
 
def Wave_function(energy):
    E = energy
    #        odeint(func, y0, t)
    #     solve_ivp(fun, t_span, y0)
    psi = solve_ivp(SE, [-B-a, L], np.array(psi0))
    print(psi)
    return psi[-1,0]
 
def find_all_zeroes(x,y):
    """
    Gives all zeroes in y = Psi(x)
    """
    all_zeroes = []
    s = sign(y)
    for i in range(len(y)-1):
        if s[i]+s[i+1] == 0:
            zero = brentq(Wave_function, x[i], x[i])
            all_zeroes.append(zero)
    return all_zeroes
 
def main():
    # main program        
 
    en = linspace(0, Vo, 100)   # vector of energies where we look for the stable states
 
    psi_b = []      # vector of wave function at x = b for all of the energies in en
    for e1 in en:
        psi_b.append(Wave_function(e1))     # for each energy e1 find the the psi(x) at x = b
    E_zeroes = find_all_zeroes(en, psi_b)   # now find the energies where psi(b) = 0 
 
    # Print energies for the bound states
    print ("Energies for the bound states are: ")
    for E in E_zeroes:
        print ("%.2f" %E)
    # Print energies of each bound state from the analytical model
    find_analytic_energies(en)   
 
if __name__ == "__main__":
    main()
That's curious. The problem I'm working to sort out is the script's inherent instability; using globals makes it a bit less predictable. Have you checked the source code for solve_ivp() yet?
ya, i'm reading into it. I seems that I got past making solve_ivp, and I removed all but one global, E. now Im trying to get find all zeros to work, I may have to rewrite it, but Im getting an error I don't understand:
Error:
Traceback (most recent call last): File "C:\...\shcrdinger.py", line 81, in <module> main() File "C:...\shcrdinger.py", line 72, in main E_zeroes = find_all_zeroes(en, psi_b) # now find the energies where psi(b) = 0 File "C:\...\shcrdinger.py", line 56, in find_all_zeroes s = np.sign(y) ValueError: could not broadcast input array from shape (2,43) into shape (2)
here is the updated code:
from pylab import *
from scipy.integrate import solve_ivp
from scipy.optimize import brentq
import numpy as np

a=1
B=4
L= B+a
Vmax= 50
Vpot = False

N = 1000                  # number of points to take
psi = np.zeros([N,2])     # Wave function values and its derivative (psi and psi')
psi0 = array([0,1])   # Wave function initial states
Vo = 50
E = 0.0                   # global variable Energy  needed for Sch.Eq, changed in function "Wave function"
b = L                     # point outside of well where we need to check if the function diverges
x = linspace(-B-a, L, N)    # x-axis
def V(x):
    '''
    #Potential function in the finite square well.
    '''
    if -a <=x <=a:
        val = Vo
    elif x<=-a-B:
        val = Vmax
    elif x>=L:
        val = Vmax
    else:
        val = 0
    if Vpot==True:
          if -a-B-(10/N) < x <= L+(1/N):
             Ypotential.append(val)
             Xpotential.append(x)
    return val
 
def SE(z, p):
    state0 = p[1]
    state1 = 1.0*(V(z) - E)*p[0]
    return array([state0, state1])
 
def Wave_function(energy):
    global E
    E = energy
    #        odeint(func, y0, t)
    #     solve_ivp(fun, t_span, y0)
    psi = solve_ivp(SE, [-B-a, L], np.array(psi0)).y
    return psi
 
def find_all_zeroes(x,y):
    """
    Gives all zeroes in y = Psi(x)
    """
    all_zeroes = []
    print(y)
    s = np.sign(y)
    for i in range(len(y.t)-1):
        if s[i]+s[i+1] == 0:
            zero = brentq(Wave_function, x[i], x[i])
            all_zeroes.append(zero)
    return all_zeroes
 
def main():
    # main program        
 
    en = linspace(0, Vo, 100)   # vector of energies where we look for the stable states
 
    psi_b = []      # vector of wave function at x = b for all of the energies in en
    for e1 in en:
        psi_b.append(Wave_function(e1))     # for each energy e1 find the the psi(x) at x = b
    
    E_zeroes = find_all_zeroes(en, psi_b)   # now find the energies where psi(b) = 0 
 
    # Print energies for the bound states
    print ("Energies for the bound states are: ")
    for E in E_zeroes:
        print ("%.2f" %E)
     
 
if __name__ == "__main__":
    main()
Pages: 1 2