PyTorch: multidimensional linear regression

Time: 2020-11-22

1. Objectives

Fit the function $f(x) = 5.0x_1 + 4.0x_2 + 3.0x_3 + 3$

2. Theory

It is similar to one-dimensional linear regression, except that the weight is a vector with one component per input variable instead of a single scalar.
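Concretely, the model predicts a weighted sum of the three inputs plus a bias, and training minimizes the mean squared error over the $N$ training samples:

$$\hat{y} = w_1 x_1 + w_2 x_2 + w_3 x_3 + b, \qquad L = \frac{1}{N} \sum_{i=1}^{N} (\hat{y}_i - y_i)^2$$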

3. Implementation

3.0 Environment

python == 3.6
torch == 1.4

3.1 Necessary packages

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

3.2 Creating the data and converting it to tensors

# f(x) = 5*x1 + 4*x2 + 3*x3 + 3
x_train = np.array([[1,3,4], [2,4,2], [7,5,9], [2,5,6], [6,4,2],
                    [8,2,7], [9,3,6], [1,6,8], [5,3,6], [3,7,3]], dtype=np.float32)  # 10 samples, 3 features each
y_train = x_train[:,0]*5 + x_train[:,1]*4 + 3*x_train[:,2] + 3  # targets from the true function
y_train = y_train.reshape((10, 1))  # shape (10, 1) to match the model output

x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)
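As a quick sanity check (my addition, not in the original post), the tensor shapes should match what nn.Linear(3, 1) expects:

print(x_train.shape, y_train.shape)  # torch.Size([10, 3]) torch.Size([10, 1])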

3.3 Building the model and creating an object

class MultiLinearRegression(nn.Module):
    def __init__(self):
        super(MultiLinearRegression, self).__init__()
        self.linear = nn.Linear(3, 1)  # three input variables map to one output
        
    def forward(self,x):
        out = self.linear(x)
        return out

model = MultiLinearRegression()
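Printing the model is an optional check (my addition); for this class it should show the single linear layer:

print(model)
# MultiLinearRegression(
#   (linear): Linear(in_features=3, out_features=1, bias=True)
# )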

3.4 Check CUDA

if torch.cuda.is_available():
    model = model.cuda()
    x_train = x_train.cuda()
    y_train = y_train.cuda()
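An equivalent, more device-agnostic pattern, a common PyTorch idiom rather than what the original uses, is:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
x_train = x_train.to(device)
y_train = y_train.to(device)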

3.5 Select the loss function and optimizer

The mean squared error (MSE) loss is used here, with stochastic gradient descent and a learning rate of 0.001.

criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)
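If plain SGD converges slowly on this data, one common alternative (my suggestion, not part of the original) is the Adam optimizer, often with a somewhat larger learning rate:

optimizer = optim.Adam(model.parameters(), lr=1e-2)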

3.6 Start training

epoch = 0
while True:
    output = model(x_train)              # forward propagation
    loss = criterion(output, y_train)    # loss calculation
    loss_value = loss.data.cpu().numpy() # get the loss value
    optimizer.zero_grad()                # zero the gradients
    loss.backward()                      # back propagation
    optimizer.step()                     # update the parameters

    epoch += 1
    if epoch % 100 == 0:                 # print once every 100 steps
        print('Epoch:{}, loss:{:.6f}'.format(epoch, loss_value))
    if loss_value <= 1e-3:
        break
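A simpler and more idiomatic way to extract the scalar loss, equivalent here, is loss.item(), which returns a Python float regardless of the device:

loss_value = loss.item()  # replaces loss.data.cpu().numpy()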

3.7 View results

w = model.linear.weight.data.cpu().numpy()
b = model.linear.bias.data.cpu().numpy()
print('w:{},b:{}'.format(w,b))

# The result is:
w:[[5.0077577 4.0204782 3.004031 ]],b:[2.851891]
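As a final check (my addition, using an arbitrary test point), the trained model's prediction can be compared against the true function:

x_test = torch.tensor([[1.0, 2.0, 3.0]])
if torch.cuda.is_available():
    x_test = x_test.cuda()
with torch.no_grad():
    print(model(x_test))  # true value: 5*1 + 4*2 + 3*3 + 3 = 25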

4. Comments are welcome