Matrix Calculator
Can do matrix addition, subtraction, multiplication, transposition, and inversion.
Matrix Calculator
Can do matrix addition, subtraction, multiplication, transposition, and inversion.
// 34:33
import java.util.Scanner;

/**
 * Console matrix calculator: addition, subtraction, scalar and matrix multiplication,
 * transposition and inversion of integer matrices.
 *
 * <p>Every matrix is stored column-first as {@code int[column][row]}: {@code m[c][r]} is the
 * entry in column {@code c} and row {@code r}. All methods below use that layout.
 */
public class array2DComplexMatrix {

    /**
     * Reads the sizes and entries of two matrices A and B from standard input, then performs
     * one user-chosen operation and prints its result.
     */
    public static void main(String[] Args) {
        //------------------------------------------- variables -------------------------------------------//
        Scanner input = new Scanner(System.in);
        int columnA;
        int rowA;
        int columnB;
        int rowB;
        boolean valid = false;

        //------------------------------ getting the size of the matrices ---------------------------------//
        do {
            System.out.println("columns in matrix A");
            columnA = checkint(input, true, 0);

            System.out.println("rows in matrix A");
            rowA = checkint(input, true, 0);

            System.out.println("columns in matrix B");
            columnB = checkint(input, true, 0);

            System.out.println("rows in matrix B");
            rowB = checkint(input, true, 0);

            System.out.println("columnA: " + columnA + "\nrowA: " + rowA + "\ncolumnB: " + columnB + "\nrowB: " + rowB);

            // Only accept sizes for which at least one of A * B, A +/- B, or the inverse of a
            // square matrix is possible.
            if (columnA == rowB || (columnA == columnB && rowA == rowB) || columnA == rowA || columnB == rowB) {
                valid = true;
            } else {
                System.out.println("matrices must be of same dimension, square, or column of matrix 1 must match row of matrix 2");
            }
        } while (!valid);

        //------------------------------ making + filling the matrices ---------------------------------//
        int[][] matrixA = readMatrix(input, "A", columnA, rowA);
        printMatrix("Matrix A", matrixA, rowA, columnA);

        int[][] matrixB = readMatrix(input, "B", columnB, rowB);
        printMatrix("Matrix B", matrixB, rowB, columnB);

        // nextInt() leaves the rest of the last number's line (at least its newline) unread;
        // discard it so the first nextLine() below returns the user's answer instead of "".
        input.nextLine();

        //------------------------------------------- what operation? -------------------------------------------//
        System.out.println("What operation?");
        valid = false;
        do {
            String ans = input.nextLine().trim();

            if (isAny(ans, "addition", "+")) {
                if (columnA == columnB && rowA == rowB) {
                    addMatrix(matrixA, matrixB, columnA, rowA);
                    valid = true;
                } else {
                    System.out.println("matrices must be of same dimensions");
                }
            } else if (isAny(ans, "substraction", "subtraction", "-")) {
                if (columnA == columnB && rowA == rowB) {
                    subMatrix(matrixA, matrixB, columnA, rowA);
                    valid = true;
                } else {
                    System.out.println("matrices must be of same dimensions");
                }
            } else if (isAny(ans, "multiplication", "*", "x")) {
                System.out.println("Scalar?");
                ans = input.nextLine().trim();
                if (ans.equalsIgnoreCase("yes")) {
                    System.out.println("Matrix A or B?");
                    ans = input.nextLine().trim();
                    if (namesMatrix(ans, "A")) {
                        System.out.println("by what number?");
                        int num = checkint(input, false, 0);
                        scalarMultiplication(matrixA, columnA, rowA, num);
                        valid = true;
                    } else if (namesMatrix(ans, "B")) {
                        System.out.println("by what number?");
                        int num = checkint(input, false, 0);
                        // Scale matrix B with its own data and dimensions.
                        scalarMultiplication(matrixB, columnB, rowB, num);
                        valid = true;
                    } else {
                        System.out.println("invalid");
                    }
                } else if (ans.equalsIgnoreCase("no")) {
                    if (columnA == rowB) {
                        multiplyMatrix(matrixA, matrixB, columnA, rowA, columnB);
                        valid = true;
                    } else {
                        System.out.println("columns in first matrix must match rows of second");
                    }
                } else {
                    System.out.println("invalid");
                }
            } else if (isAny(ans, "transpose", "transposing", "^T")) {
                System.out.println("matrix A or B?");
                ans = input.nextLine().trim();
                // A transposed column x row matrix has `column` rows and `row` columns.
                if (namesMatrix(ans, "A")) {
                    printMatrix("Matrix A transposed", transpose(matrixA, columnA, rowA), columnA, rowA);
                    valid = true;
                } else if (namesMatrix(ans, "B")) {
                    printMatrix("Matrix B transposed", transpose(matrixB, columnB, rowB), columnB, rowB);
                    valid = true;
                } else {
                    System.out.println("invalid");
                }
            } else if (isAny(ans, "inverse", "^-1")) {
                System.out.println("matrix A or B?");
                ans = input.nextLine().trim();
                if (namesMatrix(ans, "A")) {
                    if (columnA == rowA) {
                        inverseMatrix(matrixA, columnA);
                        valid = true;
                    } else {
                        System.out.println("matrix must be square");
                    }
                } else if (namesMatrix(ans, "B")) {
                    if (columnB == rowB) {
                        inverseMatrix(matrixB, columnB);
                        valid = true;
                    } else {
                        System.out.println("matrix must be square");
                    }
                } else {
                    System.out.println("invalid");
                }
            } else {
                System.out.println("invalid");
            }
        } while (!valid);
    }

    /** Returns true if {@code ans} equals any of {@code options}, ignoring case. */
    private static boolean isAny(String ans, String... options) {
        for (String option : options) {
            if (ans.equalsIgnoreCase(option)) {
                return true;
            }
        }
        return false;
    }

    /** Returns true if {@code ans} names the matrix {@code letter}: "A", "matrix A" or "matrixA", any case. */
    private static boolean namesMatrix(String ans, String letter) {
        return isAny(ans, letter, "matrix " + letter, "matrix" + letter);
    }

    /** Prompts for and reads every entry of a {@code column x row} matrix, one column at a time. */
    private static int[][] readMatrix(Scanner input, String name, int column, int row) {
        int[][] matrix = new int[column][row];
        for (int x = 0; x < column; x++) {
            for (int y = 0; y < row; y++) {
                System.out.println("What number is in position " + (x + 1) + ", " + (y + 1) + " for matrix " + name + "?");
                matrix[x][y] = checkint(input, false, 0);
            }
        }
        return matrix;
    }

    //------------------------------------------- method for checking for valid inputted integers -------------------------------------------//

    /**
     * Reads the next integer token, re-prompting until one is entered.
     *
     * @param input scanner to read from
     * @param isThereMin when true, values {@code <= min} are rejected as well
     * @param min exclusive lower bound, only used when {@code isThereMin} is true
     * @return the accepted integer
     */
    public static int checkint(Scanner input, boolean isThereMin, int min) {
        while (true) {
            while (!input.hasNextInt()) {
                System.out.println("invalid input. Try again.");
                input.nextLine();
            }
            int target = input.nextInt();
            if (!isThereMin || target > min) {
                return target;
            }
            System.out.println("invalid input. Try again.");
        }
    }

    //------------------------------------------- method for printing matrices -------------------------------------------//

    /** Prints {@code label} and then the matrix row by row, each entry followed by a space. */
    public static void printMatrix(String label, int[][] array, int row, int column) {
        System.out.println(label);
        for (int x = 0; x < row; x++) {
            for (int y = 0; y < column; y++) {
                System.out.print(array[y][x] + " ");
            }
            System.out.print("\n");
        }
    }

    //------------------------------------------- method for adding matrices -------------------------------------------//

    /** Prints A + B for two {@code column x row} matrices. */
    public static void addMatrix(int[][] arrayA, int[][] arrayB, int column, int row) {
        int[][] arrayAB = new int[column][row];
        // The first index is the column, so it runs to `column` (not `row`).
        for (int x = 0; x < column; x++) {
            for (int y = 0; y < row; y++) {
                arrayAB[x][y] = arrayA[x][y] + arrayB[x][y];
            }
        }
        printMatrix("Matrix A + Matrix B", arrayAB, row, column);
    }

    //------------------------------------------- method for subtracting matrices -------------------------------------------//

    /** Prints A - B for two {@code column x row} matrices. */
    public static void subMatrix(int[][] arrayA, int[][] arrayB, int column, int row) {
        int[][] arrayAB = new int[column][row];
        for (int x = 0; x < column; x++) {
            for (int y = 0; y < row; y++) {
                arrayAB[x][y] = arrayA[x][y] - arrayB[x][y];
            }
        }
        printMatrix("Matrix A - Matrix B", arrayAB, row, column);
    }

    //------------------------------------------- method for multiplying matrices -------------------------------------------//

    /**
     * Prints the product A * B, which has {@code rowA} rows and {@code columnB} columns.
     * The caller must ensure {@code columnA} equals the row count of B.
     */
    public static void multiplyMatrix(int[][] arrayA, int[][] arrayB, int columnA, int rowA, int columnB) {
        int[][] arrayAB = new int[columnB][rowA];
        for (int x = 0; x < columnB; x++) {
            for (int y = 0; y < rowA; y++) {
                for (int z = 0; z < columnA; z++) {
                    // (row y, column x) of the product = sum over z of A(y, z) * B(z, x).
                    arrayAB[x][y] += arrayA[z][y] * arrayB[x][z];
                }
            }
        }
        printMatrix("Matrix A * Matrix B", arrayAB, rowA, columnB);
    }

    //------------------------------------------- method for scalar multiplying -------------------------------------------//

    /** Prints the {@code column x row} matrix with every entry multiplied by {@code num}. */
    public static void scalarMultiplication(int[][] array, int column, int row, int num) {
        int[][] arrayFin = new int[column][row];
        for (int x = 0; x < column; x++) {
            for (int y = 0; y < row; y++) {
                arrayFin[x][y] = array[x][y] * num;
            }
        }
        printMatrix("Matrix", arrayFin, row, column);
    }

    //------------------------------------------- method for transposing matrices -------------------------------------------//

    /**
     * Returns the transpose of a {@code column x row} matrix, stored as {@code int[row][column]}.
     * Does not print; callers decide whether to show it.
     */
    public static int[][] transpose(int[][] array, int column, int row) {
        int[][] transposed = new int[row][column];
        for (int x = 0; x < column; x++) {
            for (int y = 0; y < row; y++) {
                transposed[y][x] = array[x][y];
            }
        }
        return transposed;
    }

    //------------------------------------------- method for finding inverse of matrices -------------------------------------------//

    /**
     * Prints the inverse of a square {@code column x column} matrix, computed as adj(A) / det(A),
     * or a message when the determinant is 0.
     */
    public static void inverseMatrix(int[][] array, int column) {
        int det = determinantMatrix(array, column);
        // Test the determinant itself: once inverted, a zero would have become Infinity.
        if (det == 0) {
            System.out.println("impossible. Determinant is 0");
            return;
        }
        double scale = 1.0 / det;
        int[][] adjoint = adjointMatrix(array, column);
        double[][] inverse = new double[column][column];
        for (int x = 0; x < column; x++) {
            for (int y = 0; y < column; y++) {
                inverse[x][y] = scale * adjoint[x][y];
            }
        }

        System.out.println("inverse Matrix");
        for (int x = 0; x < column; x++) {
            for (int y = 0; y < column; y++) {
                System.out.print(inverse[y][x] + " ");
            }
            System.out.print("\n");
        }
    }

    /** Returns (and prints) the adjugate of a square matrix: the transpose of its cofactor matrix. */
    public static int[][] adjointMatrix(int[][] array, int column) {
        int[][] cofactors = new int[column][column];
        for (int x = 0; x < column; x++) {
            for (int y = 0; y < column; y++) {
                int minor = determinantMatrix(findSmallerArray(array, column, x, y), column - 1);
                cofactors[x][y] = ((x + y) % 2 == 0) ? minor : -minor;
            }
        }
        int[][] adjoint = transpose(cofactors, column, column);
        printMatrix("adjoint", adjoint, column, column);
        return adjoint;
    }

    /**
     * Returns the determinant of a square {@code column x column} matrix, expanding along the
     * first row.
     */
    public static int determinantMatrix(int[][] array, int column) {
        if (column == 0) {
            // Empty minor (from a 1x1 matrix): its determinant is 1 by convention, which makes
            // the adjugate of a 1x1 matrix [[1]].
            return 1;
        }
        if (column == 1) {
            return array[0][0];
        }
        if (column == 2) {
            return (array[0][0] * array[1][1]) - (array[1][0] * array[0][1]);
        }
        int det = 0;
        for (int x = 0; x < column; x++) {
            int term = array[x][0] * determinantMatrix(findSmallerArray(array, column, x, 0), column - 1);
            det += (x % 2 == 0) ? term : -term;
        }
        return det;
    }

    /** Returns the minor of a square matrix with column {@code targetColumn} and row {@code targetRow} removed. */
    public static int[][] findSmallerArray(int[][] array, int column, int targetColumn, int targetRow) {
        int[][] arraySml = new int[column - 1][column - 1];
        int x = 0;
        for (int c = 0; c < column; c++) {
            if (c == targetColumn) {
                continue;
            }
            int y = 0;
            for (int r = 0; r < column; r++) {
                if (r == targetRow) {
                    continue;
                }
                arraySml[x][y++] = array[c][r];
            }
            x++;
        }
        return arraySml;
    }
}
Pong AI
A reinforcement-learning agent that learns to play Pong.
import sys

import pygame
from pygame.math import Vector2

pygame.init()

# Game window and frame clock shared by every pongAI instance.
screen = pygame.display.set_mode((640, 480))
clock = pygame.time.Clock()

# Custom event posted every 150 ms; pongAI.play_step redraws the UI whenever it arrives.
SCREEN_UPDATE = pygame.USEREVENT
pygame.time.set_timer(SCREEN_UPDATE, 150)

class pongAI:
    """Two-paddle Pong environment for a reinforcement-learning agent.

    A single agent drives both paddles through a one-hot action
    ``[w, s, up, down]``; ``play_step`` advances one frame and returns
    ``(reward, done, score)``. Uses the module-level ``screen``, ``clock``
    and ``SCREEN_UPDATE`` created when this module is imported.
    """

    # Paddles are 100 px tall; in the 480 px window their top edge is
    # clamped to [0, PADDLE_MAX_Y].
    PADDLE_HEIGHT = 100
    PADDLE_MAX_Y = 380

    def __init__(self):
        self.reward = 0
        # Start from the same layout reset() restores after every game, so the
        # first game does not begin with both paddles at the top edge.
        self.reset()

    def play_step(self, action):
        """Advance the game by one frame.

        Args:
            action: one-hot move ``[w, s, up, down]`` (paddle1 up/down,
                paddle2 up/down); any other value leaves both paddles still.

        Returns:
            ``(reward, done, score)``. ``reward`` is +1 if the move brought a
            paddle closer to the ball, +10 on a paddle hit, -10 when the ball
            left the field, otherwise 0. ``done`` is True when the game ended
            this frame (the game has then already been reset). ``score`` is
            the number of paddle hits in the game, taken before that reset.
        """
        # Rewards are per frame; never carry over the previous frame's value.
        self.reward = 0

        for event in pygame.event.get():
            # if game is closed
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()
            # redraw on the periodic SCREEN_UPDATE timer
            if event.type == SCREEN_UPDATE:
                self.update_ui()

        self._move_paddles(action)
        self._move_ball()

        if self.check_paddle_hit():
            self.ball_velocity.x *= -1
            self.score += 1
            self.reward = 10
            # Every 4th hit the ball gets faster.
            if self.score % 4 == 0:
                self._speed_up_ball()

        done = self.check_game_over()
        # Capture the score before reset() zeroes it.
        score = self.score
        if done:
            self.reset()
            self.reward = -10

        pygame.display.flip()
        clock.tick(60)
        return self.reward, done, score

    def _move_paddles(self, action):
        """Apply a one-hot action ``[w, s, up, down]`` to the matching paddle."""
        action = list(action)
        if action == [1, 0, 0, 0]:
            self._move_paddle(self.paddle1, -5)
        elif action == [0, 1, 0, 0]:
            self._move_paddle(self.paddle1, 5)
        elif action == [0, 0, 1, 0]:
            self._move_paddle(self.paddle2, -5)
        elif action == [0, 0, 0, 1]:
            self._move_paddle(self.paddle2, 5)

    def _move_paddle(self, paddle, dy):
        """Move ``paddle`` by ``dy`` px, clamped on screen; reward +1 if it got closer to the ball."""
        # paddle.y is the top edge; measure the gap from the paddle's centre.
        center_offset = self.PADDLE_HEIGHT / 2
        old_gap = abs(paddle.y + center_offset - self.ball.y)
        paddle.y = min(self.PADDLE_MAX_Y, max(0, int(paddle.y + dy)))
        if abs(paddle.y + center_offset - self.ball.y) < old_gap:
            self.reward = 1

    def _move_ball(self):
        """Advance the ball one frame and bounce it off the top and bottom walls."""
        self.ball += self.ball_velocity
        # Point the vertical velocity back into the field instead of just
        # negating it, so a ball that overshoots a wall cannot flip-flop there.
        if self.ball.y >= 475:
            self.ball_velocity.y = -abs(self.ball_velocity.y)
        elif self.ball.y <= 0:
            self.ball_velocity.y = abs(self.ball_velocity.y)

    def _speed_up_ball(self):
        """Increase both velocity components' magnitude by 1, keeping their direction."""
        # A plain `+= 1` would slow, stop or reverse a negative component.
        self.ball_velocity.x += 1 if self.ball_velocity.x > 0 else -1
        self.ball_velocity.y += 1 if self.ball_velocity.y > 0 else -1

    def reset(self):
        """Start a new game: centred ball moving down-right, centred paddles, zero score."""
        self.ball = Vector2(320, 240)
        self.ball_velocity = Vector2(1, 1)
        self.paddle1 = Vector2(25, 190)
        self.paddle2 = Vector2(615, 190)
        self.score = 0
        self.last_paddle = None
        self.reward = 0

    def check_game_over(self):
        """Return True once the ball has left the field horizontally."""
        return self.ball.x > 640 or self.ball.x < 0

    def check_paddle_hit(self):
        """Return True on the first frame the ball touches a paddle.

        ``last_paddle`` remembers which paddle was hit last, so the ball
        overlapping the same paddle on consecutive frames counts only once.
        """
        ball_rect = pygame.Rect(self.ball.x - 5, self.ball.y - 5, 10, 10)
        # NOTE(review): hit boxes are 15 px wide while update_ui draws 10 px
        # paddles - presumably a deliberate margin; confirm.
        paddles = (
            ("paddle1_rect", pygame.Rect(self.paddle1.x, self.paddle1.y, 15, self.PADDLE_HEIGHT)),
            ("paddle2_rect", pygame.Rect(self.paddle2.x, self.paddle2.y, 15, self.PADDLE_HEIGHT)),
        )
        for name, rect in paddles:
            if ball_rect.colliderect(rect):
                if self.last_paddle == name:
                    return False
                self.last_paddle = name
                return True
        return False

    def update_ui(self):
        """Draw the ball and both paddles on a black background."""
        screen.fill(pygame.Color("Black"))
        paddle1_rect = pygame.Rect(self.paddle1.x, self.paddle1.y, 10, self.PADDLE_HEIGHT)
        paddle2_rect = pygame.Rect(self.paddle2.x, self.paddle2.y, 10, self.PADDLE_HEIGHT)
        pygame.draw.circle(screen, pygame.Color("White"), [self.ball.x, self.ball.y], 5, 0)
        pygame.draw.rect(screen, pygame.Color("White"), paddle1_rect)
        pygame.draw.rect(screen, pygame.Color("White"), paddle2_rect)


# Module-level game instance, created whenever this module is imported.
pong = pongAI()

import random
import sys
from collections import deque

import numpy as np
import torch

# Project modules, kept in their original import order. Presumably `main`
# is the Pong module, whose import also runs its pygame setup - confirm.
from model import Linear_QNet, QTrainer
from main import pongAI
from helper import plot

# Replay-buffer capacity; older transitions are dropped beyond this.
MAX_MEMORY = 1000
# Transitions sampled per long-memory training step.
BATCH_SIZE = 100
# Learning rate handed to QTrainer.
LR = 0.001


class Agent:
    """Deep Q-learning agent that picks one paddle move per frame."""

    def __init__(self):
        self.n_games = 0
        self.epsilon = 0  # exploration level, recomputed in get_action
        self.gamma = 0.9  # discount rate
        # Replay buffer: once MAX_MEMORY items are stored, appending drops the oldest.
        self.memory = deque(maxlen=MAX_MEMORY)
        # 8 state features in (see get_state), one Q-value per action [w, s, up, down] out.
        self.model = Linear_QNet(8, 256, 4)
        self.trainer = QTrainer(self.model, lr=LR, gamma=self.gamma)

    def get_state(self, game):
        """Return the game state as a float32 vector.

        Order: ball x, ball y, ball velocity x, ball velocity y,
        paddle1 x, paddle1 y, paddle2 x, paddle2 y.
        """
        state = [
            game.ball.x,
            game.ball.y,
            game.ball_velocity.x,
            game.ball_velocity.y,
            game.paddle1.x,
            game.paddle1.y,
            game.paddle2.x,
            game.paddle2.y,
        ]
        return np.array(state, dtype=np.float32)

    def remember(self, state, action, reward, next_state, done):
        """Store one transition in replay memory (the oldest is evicted when full)."""
        self.memory.append((state, action, reward, next_state, done))

    def train_long_memory(self):
        """Train on a random batch of BATCH_SIZE transitions, or all of memory if smaller."""
        if len(self.memory) > BATCH_SIZE:
            mini_sample = random.sample(self.memory, BATCH_SIZE)
        else:
            mini_sample = self.memory

        # Regroup the list of transitions into parallel tuples for batched training.
        states, actions, rewards, next_states, dones = zip(*mini_sample)
        self.trainer.train_step(states, actions, rewards, next_states, dones)

    def train_short_memory(self, state, action, reward, next_state, done):
        """Train on a single transition."""
        self.trainer.train_step(state, action, reward, next_state, done)

    def get_action(self, state):
        """Return a one-hot move ``[w, s, up, down]`` chosen epsilon-greedily.

        The chance of a random move is epsilon / 201 with
        epsilon = max(5, 90 - n_games): about 45% at first, decaying to
        about 2.5%. Otherwise the move with the highest predicted Q-value
        is chosen.
        """
        self.epsilon = max(5, 90 - self.n_games)
        final_move = [0, 0, 0, 0]
        if random.randint(0, 200) < self.epsilon:
            move = random.randint(0, 3)
        else:
            state0 = torch.tensor(state, dtype=torch.float)
            # Inference only: don't build an autograd graph for this forward pass.
            with torch.no_grad():
                prediction = self.model(state0)
            move = torch.argmax(prediction).item()
        final_move[move] = 1
        return final_move


def train():
    """Train an Agent on pongAI forever (never returns).

    Each iteration plays one frame, trains on that single transition and
    stores it in replay memory. When a game ends, the game is reset, the
    agent trains on a replay batch, the model is saved on a new record
    score, and the score plot is updated.
    """
    plot_scores = []
    plot_mean_scores = []
    total_score = 0
    record = 0
    agent = Agent()
    game = pongAI()
    while True:
        # get old state
        state_old = agent.get_state(game)

        # get move
        final_move = agent.get_action(state_old)

        # perform move and get new state (the per-game summary below reports the score)
        reward, done, score = game.play_step(final_move)
        state_new = agent.get_state(game)

        # train short memory
        agent.train_short_memory(state_old, final_move, reward, state_new, done)

        # remember
        agent.remember(state_old, final_move, reward, state_new, done)

        if done:
            # train long memory, plot result
            game.reset()
            agent.n_games += 1
            agent.train_long_memory()

            if score > record:
                record = score
                agent.model.save()

            print('Game', agent.n_games, 'Score', score, 'Record:', record)

            plot_scores.append(score)
            total_score += score
            mean_score = total_score / agent.n_games
            plot_mean_scores.append(mean_score)
            plot(plot_scores, plot_mean_scores)


if __name__ == "__main__":
    # train() loops forever, so no outer loop is needed; the guard keeps an
    # import of this module from starting a training run.
    train()

import matplotlib.pyplot as plt
from IPython import display

# Interactive mode, so the figure can be refreshed from the training loop.
plt.ion()


def plot(scores, mean_scores):
    """Redraw the training curves: per-game scores and their running mean.

    Args:
        scores: score of each finished game, in order.
        mean_scores: running mean score after each game.
    """
    display.clear_output(wait=True)
    display.display(plt.gcf())
    plt.clf()
    plt.title('Training...')
    plt.xlabel('Number of Games')
    plt.ylabel('Score')
    plt.plot(scores)
    plt.plot(mean_scores)
    # `bottom` is the primary keyword for the lower limit (`ymin` is only an alias).
    plt.ylim(bottom=0)
    # Label the latest point of each curve; there is nothing to label before the first game.
    if scores:
        plt.text(len(scores) - 1, scores[-1], str(scores[-1]))
    if mean_scores:
        plt.text(len(mean_scores) - 1, mean_scores[-1], str(mean_scores[-1]))
    plt.show(block=False)
    plt.pause(.1)

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import os


class Linear_QNet(nn.Module):
    """Two-layer MLP mapping a state vector to one Q-value per action."""

    def __init__(self, input_size, hidden_size1, output_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size1)
        self.linear2 = nn.Linear(hidden_size1, output_size)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        return self.linear2(x)

    def save(self, file_name='model.pth'):
        """Save the weights (state_dict) to ./model/<file_name>, creating the folder if needed."""
        model_folder_path = './model'
        os.makedirs(model_folder_path, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(model_folder_path, file_name))


class QTrainer:
    """One-step Q-learning trainer: fits Q(s, a) to r + gamma * max_a' Q(s', a')."""

    def __init__(self, model, lr, gamma):
        self.lr = lr
        self.gamma = gamma
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=self.lr)
        self.criterion = nn.MSELoss()

    def train_step(self, state, action, reward, next_state, done):
        """Run one optimizer step on a single transition or a batch of them.

        Args:
            state, next_state: one state vector of shape (x,), or a batch
                convertible to shape (n, x).
            action: one-hot action (or a batch of them); the hot index
                selects which Q-value is trained.
            reward: scalar reward, or a batch of n rewards.
            done: bool, or a batch of n bools; terminal transitions use the
                reward alone as target.
        """
        state = torch.tensor(np.array(state), dtype=torch.float)
        next_state = torch.tensor(np.array(next_state), dtype=torch.float)
        action = torch.tensor(np.array(action), dtype=torch.long)
        reward = torch.tensor(np.array(reward), dtype=torch.float)

        if state.dim() == 1:
            # Single transition: add a batch dimension of 1.
            state = state.unsqueeze(0)
            next_state = next_state.unsqueeze(0)
            action = action.unsqueeze(0)
            reward = reward.unsqueeze(0)
            done = (done,)

        # 1: predicted Q values for the current states.
        pred = self.model(state)

        # 2: Q_new = r + gamma * max(next predicted Q) (just r when done). Built
        # without gradient tracking so only the prediction side is optimized.
        with torch.no_grad():
            next_q = self.model(next_state).max(dim=1).values
            done_mask = torch.tensor(done, dtype=torch.bool)
            q_new = torch.where(done_mask, reward, reward + self.gamma * next_q)
            target = pred.detach().clone()
            rows = torch.arange(target.shape[0])
            target[rows, action.argmax(dim=1)] = q_new

        self.optimizer.zero_grad()
        loss = self.criterion(pred, target)
        loss.backward()
        self.optimizer.step()