Projects

Matrix Calculator

Performs matrix addition, subtraction, multiplication, transposition, and inversion.

1 //34:33
2 import java.util.Scanner;
3 
/**
 * Console matrix calculator for two integer matrices A and B.
 *
 * <p>Storage convention used by every method in this class: a matrix is an
 * {@code int[columns][rows]} array, i.e. {@code m[c][r]} is the entry in column {@code c} and
 * row {@code r}. Methods that take both a {@code column} and a {@code row} argument expect the
 * number of columns and the number of rows of that array.
 */
public class array2DComplexMatrix {

  /**
   * Reads the sizes and entries of matrices A and B from standard input, then keeps asking for an
   * operation (addition, subtraction, multiplication or inverse) until one has been carried out.
   *
   * @param args unused
   */
  public static void main(String[] args) {
    Scanner input = new Scanner(System.in);
    int columnA;
    int rowA;
    int columnB;
    int rowB;

    //------------------------------ getting the size of the matrices ------------------------------//
    // Only accept sizes for which at least one operation is possible; otherwise the user could
    // never leave the operation loop further down.
    boolean sizesUsable = false;
    do {
      System.out.println("columns in matrix A");
      columnA = checkint(input, true, 0);

      System.out.println("rows in matrix A");
      rowA = checkint(input, true, 0);

      System.out.println("columns in matrix B");
      columnB = checkint(input, true, 0);

      System.out.println("rows in matrix B");
      rowB = checkint(input, true, 0);

      System.out.println("columnA: " + columnA + "\nrowA: " + rowA
          + "\ncolumnB: " + columnB + "\nrowB: " + rowB);

      if (columnA == rowB || (columnA == columnB && rowA == rowB)
          || columnA == rowA || columnB == rowB) {
        sizesUsable = true;
      } else {
        System.out.println("matrices must be of same dimension, square, or column of matrix 1 must match row of matrix 2");
      }
    } while (!sizesUsable);

    //------------------------------ making + filling the matrices ------------------------------//
    int[][] matrixA = readMatrix(input, "A", columnA, rowA);
    printMatrix("Matrix A", matrixA, rowA, columnA);

    int[][] matrixB = readMatrix(input, "B", columnB, rowB);
    printMatrix("Matrix B", matrixB, rowB, columnB);

    //------------------------------------ what operation? ------------------------------------//
    boolean sameSize = columnA == columnB && rowA == rowB;
    boolean done = false;
    do {
      System.out.println("What operation?");
      String ans = readAnswer(input);

      if (isAny(ans, "addition", "+")) {
        if (sameSize) {
          addMatrix(matrixA, matrixB, columnA, rowA);
          done = true;
        } else {
          System.out.println("matrices must be of same dimensions");
        }
      } else if (isAny(ans, "subtraction", "substraction", "-")) {
        // "substraction" is kept as an accepted spelling so existing habits keep working.
        if (sameSize) {
          subMatrix(matrixA, matrixB, columnA, rowA);
          done = true;
        } else {
          System.out.println("matrices must be of same dimensions");
        }
      } else if (isAny(ans, "multiplication", "*", "x")) {
        done = runMultiplication(input, matrixA, columnA, rowA, matrixB, columnB, rowB);
      } else if (isAny(ans, "inverse", "^-1")) {
        done = runInverse(input, matrixA, columnA, rowA, matrixB, columnB, rowB);
      } else {
        System.out.println("invalid");
      }
    } while (!done);
  }

  /** Asks for every entry of a {@code columns} x {@code rows} matrix, column by column. */
  private static int[][] readMatrix(Scanner input, String name, int columns, int rows) {
    int[][] matrix = new int[columns][rows];
    for (int x = 0; x < columns; x++) {
      for (int y = 0; y < rows; y++) {
        System.out.println("What number is in position " + (x + 1) + ", " + (y + 1)
            + " for matrix " + name + "?");
        matrix[x][y] = checkint(input, false, 0);
      }
    }
    return matrix;
  }

  /**
   * Returns the next non-blank line, trimmed. Blank lines are skipped because
   * {@link Scanner#nextInt()} leaves the line terminator of the previous numeric answer in the
   * buffer, which would otherwise be read as an empty (and therefore "invalid") answer.
   */
  private static String readAnswer(Scanner input) {
    String line = input.nextLine().trim();
    while (line.isEmpty()) {
      line = input.nextLine().trim();
    }
    return line;
  }

  /** Case-insensitive test of {@code answer} against a list of accepted spellings. */
  private static boolean isAny(String answer, String... options) {
    for (String option : options) {
      if (answer.equalsIgnoreCase(option)) {
        return true;
      }
    }
    return false;
  }

  /** Handles the multiplication sub-menu; returns true once a product has been printed. */
  private static boolean runMultiplication(Scanner input, int[][] matrixA, int columnA, int rowA,
      int[][] matrixB, int columnB, int rowB) {
    System.out.println("Scalar?");
    String ans = readAnswer(input);

    if (ans.equalsIgnoreCase("yes")) {
      System.out.println("Matrix A or B?");
      ans = readAnswer(input);
      if (isAny(ans, "Matrix A", "A")) {
        System.out.println("by what number?");
        scalarMultiplication(matrixA, columnA, rowA, checkint(input, false, 0));
        return true;
      }
      if (isAny(ans, "Matrix B", "B")) {
        System.out.println("by what number?");
        // Scale B with B's own dimensions (this branch used to scale matrix A).
        scalarMultiplication(matrixB, columnB, rowB, checkint(input, false, 0));
        return true;
      }
      System.out.println("invalid");
      return false;
    }

    if (ans.equalsIgnoreCase("no")) {
      if (columnA == rowB) {
        multiplyMatrix(matrixA, matrixB, columnA, rowA, columnB);
        return true;
      }
      System.out.println("columns in first matrix must match rows of second");
      return false;
    }

    System.out.println("invalid");
    return false;
  }

  /** Handles the inverse sub-menu; returns true once an inverse has been attempted. */
  private static boolean runInverse(Scanner input, int[][] matrixA, int columnA, int rowA,
      int[][] matrixB, int columnB, int rowB) {
    System.out.println("matrix A or B?");
    String ans = readAnswer(input);

    int[][] chosen;
    int columns;
    int rows;
    if (isAny(ans, "matrix A", "A", "matrixA")) {
      chosen = matrixA;
      columns = columnA;
      rows = rowA;
    } else if (isAny(ans, "matrix B", "B", "matrixB")) {
      chosen = matrixB;
      columns = columnB;
      rows = rowB;
    } else {
      System.out.println("invalid");
      return false;
    }

    if (columns != rows) {
      System.out.println("matrix must be square");
      return false;
    }
    inverseMatrix(chosen, columns);
    return true;
  }

  //------------------------- method for checking for valid inputted integers -------------------------//
  /**
   * Reads integers from {@code input} until an acceptable one is entered.
   *
   * @param input      scanner to read from
   * @param isThereMin whether {@code min} should be enforced
   * @param min        exclusive lower bound; only values strictly greater are accepted when
   *                   {@code isThereMin} is true
   * @return the first acceptable integer
   */
  public static int checkint(Scanner input, boolean isThereMin, int min) {
    while (true) {
      while (!input.hasNextInt()) {
        System.out.println("invalid input. Try again.");
        input.nextLine(); // throw away the offending line
      }
      int target = input.nextInt();
      if (!isThereMin || target > min) {
        return target;
      }
      System.out.println("invalid input. Try again.");
    }
  }

  //------------------------------- method for printing matrices -------------------------------//
  /**
   * Prints {@code label} followed by the matrix, one text line per matrix row.
   *
   * @param label  heading printed before the matrix
   * @param array  matrix stored as {@code [column][row]}
   * @param row    number of rows to print
   * @param column number of columns to print
   */
  public static void printMatrix(String label, int[][] array, int row, int column) {
    System.out.println(label);
    for (int r = 0; r < row; r++) {
      for (int c = 0; c < column; c++) {
        System.out.print(array[c][r] + " ");
      }
      System.out.print("\n");
    }
  }

  /** Computes {@code a + sign * b} entry by entry for two {@code column} x {@code row} matrices. */
  private static int[][] combine(int[][] a, int[][] b, int column, int row, int sign) {
    int[][] result = new int[column][row];
    // First index is the column, second the row, matching the storage convention.
    for (int x = 0; x < column; x++) {
      for (int y = 0; y < row; y++) {
        result[x][y] = a[x][y] + sign * b[x][y];
      }
    }
    return result;
  }

  //------------------------------- method for adding matrices -------------------------------//
  /**
   * Prints the sum of two equally sized matrices.
   *
   * @param arrayA first operand, {@code [column][row]}
   * @param arrayB second operand, same size as {@code arrayA}
   * @param column number of columns of both operands
   * @param row    number of rows of both operands
   */
  public static void addMatrix(int[][] arrayA, int[][] arrayB, int column, int row) {
    printMatrix("Matrix A + Matrix B", combine(arrayA, arrayB, column, row, 1), row, column);
  }

  //----------------------------- method for subtracting matrices -----------------------------//
  /**
   * Prints the difference {@code arrayA - arrayB} of two equally sized matrices.
   *
   * @param arrayA minuend, {@code [column][row]}
   * @param arrayB subtrahend, same size as {@code arrayA}
   * @param column number of columns of both operands
   * @param row    number of rows of both operands
   */
  public static void subMatrix(int[][] arrayA, int[][] arrayB, int column, int row) {
    printMatrix("Matrix A - Matrix B", combine(arrayA, arrayB, column, row, -1), row, column);
  }

  //----------------------------- method for multiplying matrices -----------------------------//
  /**
   * Prints the matrix product A * B.
   *
   * @param arrayA  left operand, {@code columnA} columns by {@code rowA} rows
   * @param arrayB  right operand, {@code columnB} columns by {@code columnA} rows
   * @param columnA number of columns of A, which must equal the number of rows of B
   * @param rowA    number of rows of A (and of the product)
   * @param columnB number of columns of B (and of the product)
   */
  public static void multiplyMatrix(int[][] arrayA, int[][] arrayB, int columnA, int rowA, int columnB) {
    int[][] product = new int[columnB][rowA];

    for (int c = 0; c < columnB; c++) {
      for (int r = 0; r < rowA; r++) {
        int sum = 0;
        for (int k = 0; k < columnA; k++) {
          sum += arrayA[k][r] * arrayB[c][k];
        }
        product[c][r] = sum;
      }
    }
    printMatrix("Matrix A * Matrix B", product, rowA, columnB);
  }

  //------------------------------ method for scalar multiplying ------------------------------//
  /**
   * Prints the matrix with every entry multiplied by {@code num}.
   *
   * @param array  matrix to scale, {@code [column][row]}
   * @param column number of columns
   * @param row    number of rows
   * @param num    scalar factor
   */
  public static void scalarMultiplication(int[][] array, int column, int row, int num) {
    int[][] scaled = new int[column][row];

    for (int x = 0; x < column; x++) {
      for (int y = 0; y < row; y++) {
        scaled[x][y] = array[x][y] * num;
      }
    }
    printMatrix("Matrix", scaled, row, column);
  }

  //----------------------------- method for transposing matrices -----------------------------//
  /**
   * Prints and returns the transpose of a matrix.
   *
   * @param array  matrix to transpose, {@code column} columns by {@code row} rows
   * @param column number of columns of {@code array}
   * @param row    number of rows of {@code array}
   * @return a new matrix with {@code row} columns and {@code column} rows
   */
  public static int[][] transpose(int[][] array, int column, int row) {
    int[][] transposed = new int[row][column];

    for (int x = 0; x < column; x++) {
      for (int y = 0; y < row; y++) {
        transposed[y][x] = array[x][y];
      }
    }
    // The transpose has `column` rows and `row` columns, so the sizes are passed swapped.
    printMatrix("transposed", transposed, column, row);
    return transposed;
  }

  //------------------------- method for finding inverse of matrices -------------------------//
  /**
   * Prints the inverse of a square matrix, computed as adjoint / determinant, or a message when
   * the matrix is singular.
   *
   * @param array  square matrix, {@code column} x {@code column}
   * @param column size of the matrix
   */
  public static void inverseMatrix(int[][] array, int column) {
    int determinant = determinantMatrix(array, column);
    // Test the determinant itself: 1/0.0 is Infinity, so checking the reciprocal never fires.
    if (determinant == 0) {
      System.out.println("impossible. Determinant is 0");
      return;
    }

    double factor = 1.0 / determinant;
    int[][] adjoint = adjointMatrix(array, column);
    double[][] inverse = new double[column][column];
    for (int x = 0; x < column; x++) {
      for (int y = 0; y < column; y++) {
        inverse[x][y] = factor * adjoint[x][y];
      }
    }

    System.out.println("inverse Matrix");
    for (int r = 0; r < column; r++) {
      for (int c = 0; c < column; c++) {
        System.out.print(inverse[c][r] + " ");
      }
      System.out.print("\n");
    }
  }

  /**
   * Computes, prints and returns the adjoint (transposed cofactor matrix) of a square matrix.
   * The intermediate call to {@link #transpose} also prints the matrix under "transposed".
   *
   * @param array  square matrix, {@code column} x {@code column}
   * @param column size of the matrix
   * @return the adjoint, stored as {@code [column][row]}
   */
  public static int[][] adjointMatrix(int[][] array, int column) {
    int[][] cofactors = new int[column][column];

    for (int x = 0; x < column; x++) {
      for (int y = 0; y < column; y++) {
        int minor = determinantMatrix(findSmallerArray(array, column, x, y), column - 1);
        // Checkerboard sign pattern of the cofactor expansion.
        cofactors[x][y] = ((x + y) % 2 == 0) ? minor : -minor;
      }
    }

    int[][] adjoint = transpose(cofactors, column, column);
    printMatrix("adjoint", adjoint, column, column);
    return adjoint;
  }

  /**
   * Computes the determinant of a square matrix by cofactor expansion along row 0.
   *
   * @param array  square matrix, {@code column} x {@code column}
   * @param column size of the matrix
   * @return the determinant; 1 for the empty (0 x 0) matrix
   */
  public static int determinantMatrix(int[][] array, int column) {
    // Base cases. The 0x0 and 1x1 cases matter because the adjoint of a 2x2 matrix is built
    // from 1x1 minors and the adjoint of a 1x1 matrix from the empty minor.
    if (column == 0) {
      return 1;
    }
    if (column == 1) {
      return array[0][0];
    }
    if (column == 2) {
      return (array[0][0] * array[1][1]) - (array[1][0] * array[0][1]);
    }

    int det = 0;
    for (int x = 0; x < column; x++) {
      int term = array[x][0] * determinantMatrix(findSmallerArray(array, column, x, 0), column - 1);
      // Signs alternate +, -, +, ... across the expansion row.
      if (x % 2 == 0) {
        det += term;
      } else {
        det -= term;
      }
    }
    return det;
  }

  /**
   * Returns the minor matrix obtained by deleting one column and one row.
   *
   * @param array        square matrix, {@code column} x {@code column}
   * @param column       size of {@code array}
   * @param targetColumn index of the column to delete
   * @param targetRow    index of the row to delete
   * @return a new {@code (column - 1)} x {@code (column - 1)} matrix
   */
  public static int[][] findSmallerArray(int[][] array, int column, int targetColumn, int targetRow) {
    int size = column - 1;
    int[][] minor = new int[size][size];

    for (int x = 0; x < size; x++) {
      // Indices at or past the deleted column/row are shifted by one in the source.
      int sourceColumn = (x < targetColumn) ? x : x + 1;
      for (int y = 0; y < size; y++) {
        int sourceRow = (y < targetRow) ? y : y + 1;
        minor[x][y] = array[sourceColumn][sourceRow];
      }
    }
    return minor;
  }
}

Pong AI

A reinforcement-learning agent that learns to play Pong.

 
1 import pygame
2 import sys
3 from pygame.math import Vector2
4 
# Module-level pygame setup shared by every pongAI instance.
pygame.init()

# 640x480 window that update_ui() draws into.
screen = pygame.display.set_mode((640, 480))
# Clock used by play_step() to cap the loop at 60 ticks per second.
clock = pygame.time.Clock()
# Custom event type; play_step() redraws the screen whenever it sees this event.
SCREEN_UPDATE = pygame.USEREVENT
# Post SCREEN_UPDATE on the event queue every 150 ms.
pygame.time.set_timer(SCREEN_UPDATE, 150)
11 
12 
class pongAI:
    """Two-paddle Pong environment that an agent advances one step at a time.

    The ball, its velocity and both paddles are ``Vector2`` values. ``play_step``
    applies one action, moves the ball, and returns ``(reward, done, score)``.
    """

    # (action pattern, paddle attribute name, vertical step) for [w, s, up, down].
    _MOVES = (
        ([1, 0, 0, 0], "paddle1", -5),
        ([0, 1, 0, 0], "paddle1", 5),
        ([0, 0, 1, 0], "paddle2", -5),
        ([0, 0, 0, 1], "paddle2", 5),
    )

    def __init__(self):
        self.paddle1 = Vector2(25, 0)
        self.paddle2 = Vector2(615, 0)
        self.ball = Vector2(320, 240)
        self.ball_velocity = Vector2(1, 1)
        self.score = 0
        self.last_paddle = None
        self.reward = 0

    def play_step(self, action):
        """Advance the game by one frame.

        :param action: one-hot list ``[w, s, up, down]``; ``w``/``s`` move paddle 1
            up/down and ``up``/``down`` move paddle 2. Any other value moves nothing.
        :return: ``(reward, done, score)`` where ``score`` is the number of paddle
            hits in the game this step belongs to (read before any reset).
        """
        for event in pygame.event.get():
            # if game is closed
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()

            # to update the visuals
            if event.type == SCREEN_UPDATE:
                self.update_ui()

        # Start from a neutral reward so a +1/+10 earned on an earlier step is not
        # reported again for a step that earned nothing.
        self.reward = 0

        self._apply_action(action)

        # to move ball
        self.ball += self.ball_velocity
        if self.ball.y >= 475 or self.ball.y <= 0:
            self.ball_velocity.y *= -1

        # check if ball hit paddle
        if self.check_paddle_hit():
            self.ball_velocity.x *= -1
            self.score += 1
            self.reward = 10
            print(self.score)
            if self.score % 4 == 0:
                print(self.ball_velocity)
                self._speed_up_ball()
                print(self.ball_velocity)

        done = self.check_game_over()
        # Read the score before reset() zeroes it, otherwise a finished game
        # would always report 0.
        score = self.score
        if done:
            self.reset()
            self.reward = -10

        pygame.display.flip()
        clock.tick(60)

        return self.reward, done, score

    def _apply_action(self, action):
        """Move the paddle selected by ``action``; reward 1 if it got closer to the ball."""
        chosen = list(action)
        for pattern, paddle_name, delta in self._MOVES:
            if chosen != pattern:
                continue
            paddle = getattr(self, paddle_name)
            old = paddle.y
            # Keep the paddle inside the vertical range 0..380.
            paddle.y = min(380, max(0, int(paddle.y + delta)))
            if abs(paddle.y - self.ball.y) < abs(old - self.ball.y):
                self.reward = 1

    def _speed_up_ball(self):
        """Increase the ball speed by 1 per axis, in the direction it is already moving.

        Adding a plain +1 would slow down (or stop) a ball travelling in a
        negative direction.
        """
        self.ball_velocity.x += 1 if self.ball_velocity.x >= 0 else -1
        self.ball_velocity.y += 1 if self.ball_velocity.y >= 0 else -1

    def reset(self):
        """Put ball and paddles back to their starting positions and clear the score."""
        self.ball = Vector2(320, 240)
        self.ball_velocity = Vector2(1, 1)
        self.paddle1 = Vector2(25, 190)
        self.paddle2 = Vector2(615, 190)
        self.score = 0
        self.last_paddle = None

    def check_game_over(self):
        """Return True when the ball has left the screen on the left or right side."""
        return self.ball.x > 640 or self.ball.x < 0

    def check_paddle_hit(self):
        """Return True when the ball touches a paddle other than the one it hit last.

        ``last_paddle`` remembers the most recent paddle so that a ball overlapping
        the same paddle for several frames is only counted once.
        """
        ball_rect = pygame.Rect(self.ball.x - 5, self.ball.y - 5, 10, 10)

        for name, paddle in (("paddle1_rect", self.paddle1), ("paddle2_rect", self.paddle2)):
            paddle_rect = pygame.Rect(paddle.x, paddle.y, 15, 100)
            if ball_rect.colliderect(paddle_rect):
                if self.last_paddle == name:
                    return False
                self.last_paddle = name
                return True

        return False

    def update_ui(self):
        """Redraw the background, the ball and both paddles."""
        screen.fill(pygame.Color("Black"))
        paddle1_rect = pygame.Rect(self.paddle1.x, self.paddle1.y, 10, 100)
        paddle2_rect = pygame.Rect(self.paddle2.x, self.paddle2.y, 10, 100)
        pygame.draw.circle(screen, pygame.Color("White"), [self.ball.x, self.ball.y], 5, 0)
        pygame.draw.rect(screen, pygame.Color("White"), paddle1_rect)
        pygame.draw.rect(screen, pygame.Color("White"), paddle2_rect)
142 
143 
# Module-level game instance, created as soon as this module is executed or imported.
pong = pongAI()
145 
 
1 import torch
2 import random
3 import numpy as np
4 from collections import deque
5 import sys
6 from model import Linear_QNet, QTrainer
7 from main import pongAI
8 from helper import plot
9 
# Maximum number of transitions kept in the agent's replay deque.
MAX_MEMORY = 1000
# Number of transitions sampled per long-memory training pass.
BATCH_SIZE = 100
# Learning rate handed to QTrainer.
LR = 0.001
13 
14 
class Agent:
    """Q-learning agent: builds states, stores transitions and picks actions."""

    def __init__(self):
        self.n_games = 0
        self.epsilon = 0  # exploration level, recomputed on every get_action call
        self.gamma = 0.9  # discount rate
        # Bounded replay buffer: once full, the oldest transition is dropped.
        self.memory = deque(maxlen=MAX_MEMORY)
        self.model = Linear_QNet(8, 256, 4)
        self.trainer = QTrainer(self.model, lr=LR, gamma=self.gamma)

    def get_state(self, game):
        """Return the 8 state features of ``game`` as a float32 array.

        Order: ball position (x, y), ball velocity (x, y), paddle 1 position (x, y),
        paddle 2 position (x, y).
        """
        features = []
        for vector in (game.ball, game.ball_velocity, game.paddle1, game.paddle2):
            features.append(vector.x)
            features.append(vector.y)
        return np.array(features, dtype=np.float32)

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer."""
        transition = (state, action, reward, next_state, done)
        self.memory.append(transition)

    def train_long_memory(self):
        """Train on a random batch of stored transitions, or on all of them if there are few."""
        batch = self.memory
        if len(self.memory) > BATCH_SIZE:
            batch = random.sample(self.memory, BATCH_SIZE)

        states, actions, rewards, next_states, dones = zip(*batch)
        self.trainer.train_step(states, actions, rewards, next_states, dones)

    def train_short_memory(self, state, action, reward, next_state, done):
        """Train on a single transition."""
        self.trainer.train_step(state, action, reward, next_state, done)

    def get_action(self, state):
        """Return a one-hot move ``[w, s, up, down]``.

        A random move is taken with a chance that shrinks as more games are
        played (never below 5 in 201); otherwise the model's best-rated move.
        """
        self.epsilon = max(5, 90 - self.n_games)
        explore = random.randint(0, 200) < self.epsilon
        if explore:
            choice = random.randint(0, 3)
        else:
            q_values = self.model(torch.tensor(state, dtype=torch.float))
            choice = torch.argmax(q_values).item()

        final_move = [0] * 4
        final_move[choice] = 1
        return final_move
75 
76 
def train():
    """Run the endless training loop: act, learn from the step, and plot after each game.

    Each iteration reads the current state, lets the agent choose a move, plays
    one step, trains on that single transition and stores it. When a game ends
    the agent trains on a batch from memory, the model is saved if the score is
    a new record, and the score curves are re-plotted. The function never returns.
    """
    plot_scores = []
    plot_mean_scores = []
    total_score = 0
    record = 0
    agent = Agent()
    game = pongAI()
    while True:
        # get old state
        state_old = agent.get_state(game)

        # get move
        final_move = agent.get_action(state_old)

        # perform move and get new state
        reward, done, score = game.play_step(final_move)
        # score is an int, so it cannot be concatenated to a str with "+".
        print("score is", score)
        state_new = agent.get_state(game)

        # train short memory
        agent.train_short_memory(state_old, final_move, reward, state_new, done)

        # remember
        agent.remember(state_old, final_move, reward, state_new, done)

        if not done:
            continue

        # train long memory, plot result
        game.reset()
        agent.n_games += 1
        agent.train_long_memory()

        if score > record:
            record = score
            agent.model.save()

        print('Game', agent.n_games, 'Score', score, 'Record:', record)

        plot_scores.append(score)
        total_score += score
        mean_score = total_score / agent.n_games
        plot_mean_scores.append(mean_score)
        plot(plot_scores, plot_mean_scores)
118 
119 
# Script entry point: training starts as soon as this module is executed.
# train() contains its own endless loop, so this outer loop only re-enters it
# if train() were ever to return.
while True:
    train()
122 
1 import matplotlib.pyplot as plt
2 from IPython import display
3 
# Turn on matplotlib's interactive mode before plot() is first called.
plt.ion()
5 
6 
def plot(scores, mean_scores):
    """Redraw the training chart with per-game scores and their running mean.

    :param scores: score of every finished game, oldest first (must be non-empty)
    :param mean_scores: running mean score after each game (must be non-empty)
    """
    # Replace the previously displayed figure instead of stacking a new one.
    display.clear_output(wait=True)
    current_figure = plt.gcf()
    display.display(current_figure)

    plt.clf()
    plt.title('Training...')
    plt.xlabel('Number of Games')
    plt.ylabel('Score')

    curves = (scores, mean_scores)
    for curve in curves:
        plt.plot(curve)
    plt.ylim(ymin=0)

    # Label the newest point of each curve with its value.
    for curve in curves:
        latest = curve[-1]
        plt.text(len(curve) - 1, latest, str(latest))

    plt.show(block=False)
    plt.pause(.1)
21 
 
1 import torch
2 import torch.nn as nn
3 import torch.optim as optim
4 import torch.nn.functional as F
5 import numpy as np
6 import os
7 
class Linear_QNet(nn.Module):
    """Feed-forward network with one ReLU hidden layer, used to estimate Q-values."""

    def __init__(self, input_size, hidden_size1, output_size):
        """Create the two linear layers: input -> hidden and hidden -> output."""
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size1)
        self.linear2 = nn.Linear(hidden_size1, output_size)

    def forward(self, x):
        """Return the output-layer activations for input ``x`` (no final non-linearity)."""
        hidden = F.relu(self.linear1(x))
        return self.linear2(hidden)

    def save(self, file_name='model.pth'):
        """Write the model's state dict to ``./model/<file_name>``, creating the folder if missing."""
        model_folder_path = './model'
        folder_missing = not os.path.exists(model_folder_path)
        if folder_missing:
            os.makedirs(model_folder_path)

        target_path = os.path.join(model_folder_path, file_name)
        torch.save(self.state_dict(), target_path)
26 
27 
class QTrainer:
    """One-step Q-learning trainer (Adam optimiser, mean-squared-error loss)."""

    def __init__(self, model, lr, gamma):
        """
        :param model: network mapping a state (or batch of states) to one Q-value per action
        :param lr: learning rate for the Adam optimiser
        :param gamma: discount factor applied to the best next-state Q-value
        """
        self.lr = lr
        self.gamma = gamma
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=self.lr)
        self.criterion = nn.MSELoss()

    def train_step(self, state, action, reward, next_state, done):
        """Run one optimisation step on a single transition or a batch of transitions.

        For a single transition pass one state vector, a one-hot action, a scalar
        reward and a bool ``done``; for a batch pass equally long sequences of these.

        The regression target equals the current prediction except at the taken
        action, where it is ``reward`` if the episode ended and
        ``reward + gamma * max(Q(next_state))`` otherwise.
        """
        state = torch.tensor(np.array(state), dtype=torch.float)
        next_state = torch.tensor(np.array(next_state), dtype=torch.float)
        action = torch.tensor(np.array(action), dtype=torch.long)
        reward = torch.tensor(np.array(reward), dtype=torch.float)

        if state.dim() == 1:
            # Single transition: add a batch dimension so the code below is uniform.
            state = state.unsqueeze(0)
            next_state = next_state.unsqueeze(0)
            action = action.unsqueeze(0)
            reward = reward.unsqueeze(0)
            done = (done,)

        finished = torch.tensor(np.array(done, dtype=bool))

        # 1: predicted Q values with current state
        pred = self.model(state)

        # 2: Q_new = r + y * max(next_predicted Q value) -> only if not done.
        # Built without gradient tracking so the target acts as a constant and the
        # loss only back-propagates through `pred`.
        with torch.no_grad():
            best_next = self.model(next_state).max(dim=1).values
            q_new = torch.where(finished, reward, reward + self.gamma * best_next)
            target = pred.detach().clone()
            rows = torch.arange(target.shape[0])
            target[rows, action.argmax(dim=1)] = q_new

        self.optimizer.zero_grad()
        loss = self.criterion(pred, target)
        loss.backward()

        self.optimizer.step()
74