-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathLassoRegression.py
More file actions
43 lines (37 loc) · 877 Bytes
/
LassoRegression.py
File metadata and controls
43 lines (37 loc) · 877 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#-*-coding:utf-8-*-
import numpy as np
class LassoRegression(object):
    """Linear model fitted by a greedy search over fixed-size weight steps.

    Starting from all-zero weights, each pass tries moving every weight by
    ``-epos`` and ``+epos`` in turn and keeps a move as soon as it lowers the
    sum of squared errors.

    NOTE: despite the class name, the only objective evaluated here is the
    plain sum of squared errors; no explicit L1 penalty term is computed.

    Attributes set by ``fit``:
        weight: array of shape (n, 1) holding the fitted coefficients.
        squareerror: sum of squared errors achieved by ``weight``.
    """

    def __init__(self, iters=300, epos=0.01):
        """Store the search settings.

        params:
            iters: maximum number of passes over all weights.
            epos: size of the step tried on a single weight.
        """
        self.iters = iters
        self.epos = epos

    @staticmethod
    def squareError(y, yhat):
        """Return the sum of squared differences between ``y`` and ``yhat``."""
        return np.sum((y - yhat) ** 2)

    def fit(self, X, y):
        '''
        Fit the weights to the data and return ``self``.

        params:
            X: n * m (one row per feature, one column per sample)
            y: 1 * m; a flat or column-shaped target holding m values is
               aligned to 1 * m before fitting
        '''
        X = np.asarray(X)
        y = np.asarray(y)
        n = X.shape[0]
        self.weight = np.zeros((n, 1))
        yhat = np.dot(self.weight.T, X)
        # A column vector (m, 1) would broadcast against the (1, m)
        # prediction into an (m, m) residual matrix and silently optimise
        # the wrong objective, so align y with the prediction layout
        # whenever the element counts agree. Other shapes are left as given.
        if y.size == yhat.size:
            y = y.reshape(yhat.shape)
        self.squareerror = self.squareError(y, yhat)
        for _ in range(self.iters):
            improved = False
            for j in range(n):
                for direction in (-1, 1):
                    # Try a single step on weight j, starting from the
                    # current best weights (which may already include an
                    # accepted step from the other direction).
                    candidate = self.weight.copy()
                    candidate[j] += self.epos * direction
                    candidate_error = self.squareError(
                        y, np.dot(candidate.T, X))
                    if candidate_error < self.squareerror:
                        self.weight = candidate
                        self.squareerror = candidate_error
                        improved = True
            if not improved:
                # Nothing changed during this pass, so every later pass
                # would evaluate the same candidates and reject them again.
                break
        return self
# Demo: generate noiseless targets from known weights, fit the model on
# random inputs, and print the fitted weights.
W = np.array([1, 0, 2]).reshape(3, 1)
X = np.random.randn(3, 5)
y = W.T.dot(X)
lr = LassoRegression()
lr.fit(X, y)
print(lr.weight)