Python Forum

What is wrong with this implementation of the cost function for linear regression?
I am trying to write the cost function for regularised linear regression. (I am doing exercise 5 from the Stanford Coursera course on machine learning.) In the code below I also include a print of the data so that you can copy it into a .txt file if you want to try it yourself.

import numpy as np
from scipy.io import loadmat

data = loadmat('ex5data1.mat')
X = data['X']
X = np.insert(X, 0, 1, axis=1)
y = data['y']
theta = np.ones((2, 1))

print X
print y
print ' '

def cost_function(theta, X, y, reg_param):
    theta = np.matrix(theta)
    X = np.matrix(X)
    y = np.matrix(y)
    m = y.shape
    h = X * theta  # X is m x n and theta is n x 1, so h is m x 1
    error = np.power((h - y), 2)  # Elementwise square of h (m x 1) minus y (m x 1)
    term1 = np.sum(error) / 2*y.shape[0]  # Non-regularisation term in the cost function
    reg = (reg_param * np.sum(np.power(theta[1:, :], 2))) / 2*y.shape[0]  # Regularisation term in the cost function
    return term1 + reg

print cost_function(theta, X, y, 1)
Output:
[[ 1. -15.93675813]
 [ 1. -29.15297922]
 [ 1. 36.18954863]
 [ 1. 37.49218733]
 [ 1. -48.05882945]
 [ 1. -8.94145794]
 [ 1. 15.30779289]
 [ 1. -34.70626581]
 [ 1. 1.38915437]
 [ 1. -44.38375985]
 [ 1. 7.01350208]
 [ 1. 22.76274892]]
[[ 2.13431051]
 [ 1.17325668]
 [ 34.35910918]
 [ 36.83795516]
 [ 2.80896507]
 [ 2.12107248]
 [ 14.71026831]
 [ 2.61418439]
 [ 3.74017167]
 [ 3.73169131]
 [ 7.62765885]
 [ 22.7524283 ]]

43775.0196797
The correct output should be ~303, so mine is waaaay off, meaning that I must be making some big mistake in my implementation.

Solved

Silly mistake: make sure that you don't make the following mistake when dividing in python:

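1 / 2*m is not the same as 1 / (2*m): the first is evaluated left to right as (1 / 2) * m, so it multiplies by m instead of dividing by it.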
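For anyone landing here later, here is a minimal sketch of the cost function with the parentheses in place. It is not the exact code from the question: it assumes plain NumPy arrays instead of np.matrix, and it hard-codes the 12 data points from the printed output above so that it runs without ex5data1.mat. The names x_raw and cost are just illustrative.

```python
import numpy as np

# Data copied from the printed output in the question (12 training examples).
x_raw = np.array([-15.93675813, -29.15297922, 36.18954863, 37.49218733,
                  -48.05882945, -8.94145794, 15.30779289, -34.70626581,
                  1.38915437, -44.38375985, 7.01350208, 22.76274892])
y = np.array([2.13431051, 1.17325668, 34.35910918, 36.83795516,
              2.80896507, 2.12107248, 14.71026831, 2.61418439,
              3.74017167, 3.73169131, 7.62765885, 22.7524283])

# Prepend the column of ones, as np.insert(X, 0, 1, axis=1) does in the question.
X = np.column_stack([np.ones_like(x_raw), x_raw])
theta = np.ones(2)


def cost_function(theta, X, y, reg_param):
    m = y.shape[0]                       # number of training examples
    residual = X.dot(theta) - y          # h - y, one entry per example
    # Parentheses around 2 * m so that we divide by 2m instead of multiplying by m.
    term1 = np.sum(residual ** 2) / (2 * m)
    # The bias term theta[0] is left out of the regularisation.
    reg = reg_param * np.sum(theta[1:] ** 2) / (2 * m)
    return term1 + reg


cost = cost_function(theta, X, y, 1)
```

With this data, cost should come out close to the ~303 mentioned in the question. The original result was too large by a factor of m squared (144 here), since the sum was multiplied by m / 2 rather than divided by 2m.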
(Dec-22-2017, 07:24 PM) JoeB wrote: make sure that you don't make the following mistake when dividing in python


Well, not only in Python. They are not the same in 'standard' math calculations either :-)
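A quick way to see the left-to-right rule for yourself is a small check like the one below. It is only an illustration: m = 12.0 is the number of examples in the question, written as a float because in Python 2 an integer 1 / 2 would truncate to 0. The variable names are made up for this example.

```python
m = 12.0  # number of training examples in the question, as a float

# Division and multiplication have equal precedence and are applied left to right.
left_to_right = 1 / 2.0 * m    # parsed as (1 / 2.0) * m
intended = 1 / (2.0 * m)       # parentheses force the product to be computed first
```

Here left_to_right is 6.0, while intended is 1/24, which is the factor the cost function actually needs.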