Python Forum
odeint to solve Schrodinger equation
Thread Rating:
  • 0 Vote(s) - 0 Average
  • 1
  • 2
  • 3
  • 4
  • 5
odeint to solve Schrodinger equation
#1
Hi, Im trying to solve the Schrodinger equation. I am basing myself on this site but in altering the code odeint is giving me the wrong results. the functions find_all_zeroes(x,y) and find_analytic_energies(en) are supposed to give me the the same results but they are vastly different. This is the altered code I am using for the part in question. Can someone tell me what I am doing wrong?

from pylab import *
from scipy.integrate import odeint
from scipy.optimize import brentq

a=1
B=4
L= B+a
Vmax= 50
Vpot = False
def V(x):
    '''
    #Potential function in the finite square well.
    '''
    if -a <=x <=a:
        val = Vo
    elif x<=-a-B:
        val = Vmax
    elif x>=L:
        val = Vmax
    else:
        val = 0
    if Vpot==True:
          if -a-B-(10/N) < x <= L+(1/N):
             Ypotential.append(val)
             Xpotential.append(x)
    return val
 
def SE(psi, x):
    """
    Returns derivatives for the 1D schrodinger eq.
    Requires global value E to be set somewhere. State0 is first derivative of the
    wave function psi, and state1 is its second derivative.
    """
    state0 = psi[1]
    state1 = 2.0*(V(x) - E)*psi[0]
    return array([state0, state1])
 
def Wave_function(energy):
    """
    Calculates wave function psi for the given value
    of energy E and returns value at point b
    """
    global psi
    global E
    E = energy
    psi = odeint(SE, psi0, x)
    return psi[-1,0]
 
def find_all_zeroes(x,y):
    """
    Gives all zeroes in y = Psi(x)
    """
    all_zeroes = []
    s = sign(y)
    for i in range(len(y)-1):
        if s[i]+s[i+1] == 0:
            zero = brentq(Wave_function, x[i], x[i])
            all_zeroes.append(zero)
    return all_zeroes
 
def find_analytic_energies(en):
    """
    Calculates Energy values for the finite square well using analytical
    model (Griffiths, Introduction to Quantum Mechanics, 1st edition, page 62.)
    """
    z = sqrt(2*en)
    z0 = sqrt(2*Vo)
    z_zeroes = []
    f_sym = lambda z: tan(z)-sqrt((z0/z)**2-1)      # Formula 2.138, symmetrical case
    f_asym = lambda z: -1/tan(z)-sqrt((z0/z)**2-1)  # Formula 2.138, antisymmetrical case
 
    # first find the zeroes for the symmetrical case
    s = sign(f_sym(z))
    for i in range(len(s)-1):   # find zeroes of this crazy function
       if s[i]+s[i+1] == 0:
           zero = brentq(f_sym, z[i], z[i+1])
           z_zeroes.append(zero)
    print ("Energies from the analyitical model are: ")
    print ("Symmetrical case)")
    for i in range(0, len(z_zeroes),2):   # discard z=(2n-1)pi/2 solutions cause that's where tan(z) is discontinous
        print ("%.4f" %(z_zeroes[i]**2/2))
    # Now for the asymmetrical
    z_zeroes = []
    s = sign(f_asym(z))
    for i in range(len(s)-1):   # find zeroes of this crazy function
       if s[i]+s[i+1] == 0:
           zero = brentq(f_asym, z[i], z[i+1])
           z_zeroes.append(zero)
    print ("(Antisymmetrical case)")
    for i in range(0, len(z_zeroes),2):   # discard z=npi solutions cause that's where ctg(z) is discontinous
        print ("%.4f" %(z_zeroes[i]**2/2))
 
N = 1000                  # number of points to take
psi = np.zeros([N,2])     # Wave function values and its derivative (psi and psi')
psi0 = array([0,1])   # Wave function initial states
Vo = 50
E = 0.0                   # global variable Energy  needed for Sch.Eq, changed in function "Wave function"
b = L                     # point outside of well where we need to check if the function diverges
x = linspace(-B-a, L, N)    # x-axis
 
def main():
    # main program        
 
    en = linspace(0, Vo, 1000000)   # vector of energies where we look for the stable states
 
    psi_b = []      # vector of wave function at x = b for all of the energies in en
    for e1 in en:
        psi_b.append(Wave_function(e1))     # for each energy e1 find the the psi(x) at x = b
    E_zeroes = find_all_zeroes(en, psi_b)   # now find the energies where psi(b) = 0 
 
    # Print energies for the bound states
    print ("Energies for the bound states are: ")
    for E in E_zeroes:
        print ("%.2f" %E)
    # Print energies of each bound state from the analytical model
    find_analytic_energies(en)   
 
    # Plot wave function values at b vs energy vector
    figure()
    plot(en/Vo,psi_b)
    title('Values of the $\Psi(b)$ vs. Energy')
    xlabel('Energy, $E/V_0$')
    ylabel('$\Psi(x = b)$', rotation='horizontal')
    for E in E_zeroes:
        plot(E/Vo, [0], 'go')
        annotate("E = %.2f"%E, xy = (E/Vo, 0), xytext=(E/Vo, 30))
    grid()
 
    # Plot the wavefunctions for first 4 eigenstates
    figure(2)
    for E in E_zeroes[0:4]:
        Wave_function(E)
        plot(x, psi[:,0], label="E = %.2f"%E)
    legend(loc="upper right")
    title('Wave function')
    xlabel('x, $x/L$')
    ylabel('$\Psi(x)$', rotation='horizontal', fontsize = 15)
    grid()

    figure(3)
    pot =[]
    for i in x:
        pot.append(V(i))
    plot(x,pot)
    show()
if __name__ == "__main__":
    main()
Reply
#2
That code needs some serious refactoring. I've been digging into it for a few... twenty minutes... -ish... and I haven't sussed out the deal with E. That global variable doesn't appear do actually do anything in the code. Every subsequent instance when a variable E is used, the variable is instantiated in a for loop, so the global isn't being used at all, it seems. Plus, it gets set as a global from a function - very bad practice.

Can you provide the expected results for the problematic functions when given known arguments?
Reply
#3
the function find_analytic_energies(en) finds the correct values, the list E_zeroes (populated in line 109) should contain the same values, problem is that I need the function Wave function to work for other things down the line. If I try to use the values generated via ind_analytic_energies(en) with Wave function(energy) I do not get the expected results, ie the wave isn't at zero on both ends of V(x).
Reply
#4
Okay, I'm going to work on cleaning up the code. There are some peculiarities that I will need explained by someone who understands the calculations at work. So, I will post my marked up version later with lots of line comments. If you could clarify those lines, I would appreciate it.
Reply
#5
thanks, I'll try to answer everything, I sould be able to since I did the math by hand. I also asked here, though it concerns it's more concerned with the method used and not code debugging.
Reply
#6
I uploaded the file to my GitHub. I added some line comments with questions about the code. Right now, I need clarification on what each function and variable truly is. The names the original author selected are not descriptive so I don't know if V() is for Vector() or Velocity() or any number of other concepts beginning with a "v".

For the first round of refactoring, I'm going to focus on lateral changes to correct some of the bad practices (such as setting globals from inside a function). That should not change anything operationally.
Reply
#7
I think I answered all of the questions in the code
V() would work better as a closure or class method Needs a more descriptive name; what is V? Vector? Velocity? V is the potential function for the system, eg imagin ammonia molecule, it is shaped like a pyramid with three hydrogen at the base and nitrogen at the top. the potential Vo represents the area with the hydrogen atoms. the area where V(x)= 0 are the places that you'd probably find N. when Wave_function is called and all the zeros are found i have the energy with that energy I can plot the wave function onto a graph, but like the wave function is not being solved correctly because when I plot the wave using the correct energy that is found by find_analytic_energies, at x =-a-B the wave function is zero( which is correct) but at x = L the wave function is not zero, like this picture:
[Image: Infinite+Potential+Well+…+bottom+line.jpg]
that leads me to conclude that the problem is with the solution of the system of ordinary differential equations. I can't tell if I followed the advice given to me on scicomp
"If you really want to deal with an infinite potential well, then you should set b=L
and enforce the boundary condition ψ(b)=0. In this case it also makes sense
to start shooting at x=−b, with ψ(−b)=0and ψ′(b)nonzero." - LonelyProf

from pylab import *
from scipy.integrate import odeint
from scipy.optimize import brentq

a=1 # Never changes. Constant? yes
B=4 # Never changes. Constant? yes
L= B+a # Never changes. Constant? yes these tree are here more for me to easily change
Vmax= 50
Vpot = False


# V() would work better as a closure or class method
# Needs a more descriptive name; what is V? Vector? Velocity?
# V is the potential function for the sistem, eg imagin ammonia molecule,
# it is shaped like a pyramid
 with three hydrogen at the baseand nitrogen
# at the top. the potential Vo represents the area with the the hydrogen
# atoms. the area where V(x)= 0 are the places that you'd probably find N.
# when Wave_function is called and all the zeros are found i have the energy
# with that energy I can plot the wave function onto a graph, but like the
# wave function is not being solved correctly because when I plot  the wave
# using the correct energy that is found by find_analytic_energies, at x =-a-B
# the wave function is zero( whichis correct) but at x = L the wave function is not zero
# that leads me to conclude that the problem is with the solution of the system of ordinary
# differential equations. I can't tell if I followed the advice given to me on scicomp
# "If you really want to deal with an infinite potential well, then you should set  b=L
# and enforce the boundary condition ψ(b)=0. In this case it also makes sense
# to start shooting at x=−b, with ψ(−b)=0and ψ′(b)nonzero." - LonelyProf

def V(x):
    '''
    #Potential function in the finite square well.
    '''
    if -a <=x <=a:
        val = Vo
    elif x<=-a-B:
        val = Vmax
    elif x>=L:
        val = Vmax
    else:
        val = 0
    # This conditional can never be entered     #### this is here for parts of the code that come later on
    ##                                               I tried reducing things here to a min. for the problem to be clearer
    if Vpot==True: # never the case, Vpot does not change
          if -a-B-(10/N) < x <= L+(1/N):
             Ypotential.append(val) # sequence does not exist
             Xpotential.append(x) # sequence does not exist
    return val

def SE(psi, x):
    """
    Returns derivatives for the 1D schrodinger eq.
    Requires global value E to be set somewhere. State0 is first derivative of the
    wave function psi, and state1 is its second derivative.
    """
    state0 = psi[1]
    state1 = 2.0*(V(x) - E)*psi[0]
    return array([state0, state1])

def Wave_function(energy):
    """
    Calculates wave function psi for the given value
    of energy E and returns value at point b
    """
    global psi # Functions should not call global variables
    global E # Functions should not call global variables
    E = energy # Functions should not set global variables from within
    psi = odeint(SE, psi0, x) # Functions should not set global variables from within
    return psi[-1,0]

def find_all_zeroes(x,y):
    """
    Gives all zeroes in y = Psi(x)
    """
    all_zeroes = []
    s = sign(y)
    for i in range(len(y)-1):
        if s[i]+s[i+1] == 0:
            zero = brentq(Wave_function, x[i], x[i])
            all_zeroes.append(zero)
    return all_zeroes

def find_analytic_energies(en):
    """
    Calculates Energy values for the finite square well using analytical
    model (Griffiths, Introduction to Quantum Mechanics, 1st edition, page 62.)
    """
    z = sqrt(2*en)
    z0 = sqrt(2*Vo)
    z_zeroes = []
    f_sym = lambda z: tan(z)-sqrt((z0/z)**2-1)      # Formula 2.138, symmetrical case
    f_asym = lambda z: -1/tan(z)-sqrt((z0/z)**2-1)  # Formula 2.138, antisymmetrical case

    # first find the zeroes for the symmetrical case
    s = sign(f_sym(z))
    for i in range(len(s)-1):   # find zeroes of this crazy function
       if s[i]+s[i+1] == 0:
           zero = brentq(f_sym, z[i], z[i+1])
           z_zeroes.append(zero)
    print ("Energies from the analyitical model are: ")
    print ("Symmetrical case)")
    for i in range(0, len(z_zeroes),2):   # discard z=(2n-1)pi/2 solutions cause that's where tan(z) is discontinous
        print ("%.4f" %(z_zeroes[i]**2/2))
    # Now for the asymmetrical
    z_zeroes = []
    s = sign(f_asym(z))
    for i in range(len(s)-1):   # find zeroes of this crazy function
       if s[i]+s[i+1] == 0:
           zero = brentq(f_asym, z[i], z[i+1])
           z_zeroes.append(zero)
    print ("(Antisymmetrical case)")
    for i in range(0, len(z_zeroes),2):   # discard z=npi solutions cause that's where ctg(z) is discontinous
        print ("%.4f" %(z_zeroes[i]**2/2))

N = 1000                  # number of points to take
psi = np.zeros([N,2])     # Wave function values and its derivative (psi and psi')
psi0 = array([0,1])   # Wave function initial states
Vo = 50
E = 0.0                   # global variable Energy  needed for Sch.Eq, changed in function "Wave function"
b = L                     # point outside of well where we need to check if the function diverges
x = linspace(-B-a, L, N)    # x-axis

def main():
    # main program

    en = linspace(0, Vo, 1000000)   # vector of energies where we look for the stable states

    psi_b = []      # vector of wave function at x = b for all of the energies in en
    for e1 in en:
        psi_b.append(Wave_function(e1))     # for each energy e1 find the the psi(x) at x = b
    E_zeroes = find_all_zeroes(en, psi_b)   # now find the energies where psi(b) = 0

    # Print energies for the bound states
    print ("Energies for the bound states are: ")
    for E in E_zeroes:
        print ("%.2f" %E)
    # Print energies of each bound state from the analytical model
    find_analytic_energies(en)

    # Plot wave function values at b vs energy vector
    figure()
    plot(en/Vo,psi_b)
    title('Values of the $\Psi(b)$ vs. Energy')
    xlabel('Energy, $E/V_0$')
    ylabel('$\Psi(x = b)$', rotation='horizontal')
    for E in E_zeroes:
        plot(E/Vo, [0], 'go')
        annotate("E = %.2f"%E, xy = (E/Vo, 0), xytext=(E/Vo, 30))
    grid()

    # Plot the wavefunctions for first 4 eigenstates
    figure(2)
    for E in E_zeroes[0:4]:
        Wave_function(E)
        plot(x, psi[:,0], label="E = %.2f"%E)
    legend(loc="upper right")
    title('Wave function')
    xlabel('x, $x/L$')
    ylabel('$\Psi(x)$', rotation='horizontal', fontsize = 15)
    grid()

    figure(3)
    pot =[]
    for i in x:
        pot.append(V(i))
    plot(x,pot)
    show()
if __name__ == "__main__":
    main()
Reply
#8
I'm currently trying to swap odeint out for scipy.integrate.solve_ivp because on odeint's page it says this:
Quote:Note
For new code, use scipy.integrate.solve_ivp to solve a differential equation.


I tried swapping, but I keep getting this error:
Error:
...Programs\Python\Python36-32\lib\site-packages\scipy\integrate\_ivp\rk.py", line 67, in rk_step K[0] = f ValueError: could not broadcast input array from shape (2,2) into shape (2)
this is what wavefunction looks like now:
from pylab import *
from scipy.integrate import solve_ivp
from scipy.optimize import brentq

a=1
B=4
L= B+a
Vmax= 50
Vpot = False

N = 1000                  # number of points to take
psi = np.zeros([N,2])     # Wave function values and its derivative (psi and psi')
psi0 = array([0,1])   # Wave function initial states
Vo = 50
E = 0.0                   # global variable Energy  needed for Sch.Eq, changed in function "Wave function"
b = L                     # point outside of well where we need to check if the function diverges
x = linspace(-B-a, L, N)    # x-axis
def V(x):
    '''
    #Potential function in the finite square well.
    '''
    if -a <=x <=a:
        val = Vo
    elif x<=-a-B:
        val = Vmax
    elif x>=L:
        val = Vmax
    else:
        val = 0
    if Vpot==True:
          if -a-B-(10/N) < x <= L+(1/N):
             Ypotential.append(val)
             Xpotential.append(x)
    return val
 
def SE(x, p):
    state0 = psi[1]
    state1 = 1.0*(V(x) - E)*psi[0]
    return array([state0, state1])
 
def Wave_function(energy):
    E = energy
    #        odeint(func, y0, t)
    #     solve_ivp(fun, t_span, y0)
    psi = solve_ivp(SE, [-B-a, L], np.array(psi0))
    print(psi)
    return psi[-1,0]
 
def find_all_zeroes(x,y):
    """
    Gives all zeroes in y = Psi(x)
    """
    all_zeroes = []
    s = sign(y)
    for i in range(len(y)-1):
        if s[i]+s[i+1] == 0:
            zero = brentq(Wave_function, x[i], x[i])
            all_zeroes.append(zero)
    return all_zeroes
 
def main():
    # main program        
 
    en = linspace(0, Vo, 100)   # vector of energies where we look for the stable states
 
    psi_b = []      # vector of wave function at x = b for all of the energies in en
    for e1 in en:
        psi_b.append(Wave_function(e1))     # for each energy e1 find the the psi(x) at x = b
    E_zeroes = find_all_zeroes(en, psi_b)   # now find the energies where psi(b) = 0 
 
    # Print energies for the bound states
    print ("Energies for the bound states are: ")
    for E in E_zeroes:
        print ("%.2f" %E)
    # Print energies of each bound state from the analytical model
    find_analytic_energies(en)   
 
if __name__ == "__main__":
    main()
Reply
#9
That's curious. The problem I'm working to sort out is the script's inherent instability; using globals makes it a bit less predictable. Have you checked the source code for solve_ivp() yet?
Reply
#10
ya, i'm reading into it. I seems that I got past making solve_ivp, and I removed all but one global, E. now Im trying to get find all zeros to work, I may have to rewrite it, but Im getting an error I don't understand:
Error:
Traceback (most recent call last): File "C:\...\shcrdinger.py", line 81, in <module> main() File "C:...\shcrdinger.py", line 72, in main E_zeroes = find_all_zeroes(en, psi_b) # now find the energies where psi(b) = 0 File "C:\...\shcrdinger.py", line 56, in find_all_zeroes s = np.sign(y) ValueError: could not broadcast input array from shape (2,43) into shape (2)
here is the updated code:
from pylab import *
from scipy.integrate import solve_ivp
from scipy.optimize import brentq
import numpy as np

a=1
B=4
L= B+a
Vmax= 50
Vpot = False

N = 1000                  # number of points to take
psi = np.zeros([N,2])     # Wave function values and its derivative (psi and psi')
psi0 = array([0,1])   # Wave function initial states
Vo = 50
E = 0.0                   # global variable Energy  needed for Sch.Eq, changed in function "Wave function"
b = L                     # point outside of well where we need to check if the function diverges
x = linspace(-B-a, L, N)    # x-axis
def V(x):
    '''
    #Potential function in the finite square well.
    '''
    if -a <=x <=a:
        val = Vo
    elif x<=-a-B:
        val = Vmax
    elif x>=L:
        val = Vmax
    else:
        val = 0
    if Vpot==True:
          if -a-B-(10/N) < x <= L+(1/N):
             Ypotential.append(val)
             Xpotential.append(x)
    return val
 
def SE(z, p):
    state0 = p[1]
    state1 = 1.0*(V(z) - E)*p[0]
    return array([state0, state1])
 
def Wave_function(energy):
    global E
    E = energy
    #        odeint(func, y0, t)
    #     solve_ivp(fun, t_span, y0)
    psi = solve_ivp(SE, [-B-a, L], np.array(psi0)).y
    return psi
 
def find_all_zeroes(x,y):
    """
    Gives all zeroes in y = Psi(x)
    """
    all_zeroes = []
    print(y)
    s = np.sign(y)
    for i in range(len(y.t)-1):
        if s[i]+s[i+1] == 0:
            zero = brentq(Wave_function, x[i], x[i])
            all_zeroes.append(zero)
    return all_zeroes
 
def main():
    # main program        
 
    en = linspace(0, Vo, 100)   # vector of energies where we look for the stable states
 
    psi_b = []      # vector of wave function at x = b for all of the energies in en
    for e1 in en:
        psi_b.append(Wave_function(e1))     # for each energy e1 find the the psi(x) at x = b
    
    E_zeroes = find_all_zeroes(en, psi_b)   # now find the energies where psi(b) = 0 
 
    # Print energies for the bound states
    print ("Energies for the bound states are: ")
    for E in E_zeroes:
        print ("%.2f" %E)
     
 
if __name__ == "__main__":
    main()
Reply


Possibly Related Threads…
Thread Author Replies Views Last Post
  Mukhanov equation + odeint Messier087 0 1,622 Mar-28-2020, 04:03 PM
Last Post: Messier087
  Odeint to solve Mukhanov equation Messier087 4 2,570 Feb-17-2020, 05:05 PM
Last Post: Messier087
  Can't find a way to solve this nonlinear equation Alex009988 2 2,659 Aug-16-2019, 01:50 AM
Last Post: scidam
  python odeint keeps giving me size of array error kiyoshi7 1 6,160 Nov-01-2018, 02:03 AM
Last Post: j.crater

Forum Jump:

User Panel Messages

Announcements
Announcement #1 8/1/2020
Announcement #2 8/2/2020
Announcement #3 8/6/2020