Matrix Calculator
Performs matrix addition, subtraction, multiplication, transposition, and inversion.
Matrix Calculator
Performs matrix addition, subtraction, multiplication, transposition, and inversion.
import java.util.Scanner;

/**
 * Console matrix calculator for two integer matrices, A and B.
 *
 * <p>Supports addition, subtraction, scalar and matrix multiplication, and
 * inversion. Every matrix in this class is stored column-major, i.e. as
 * {@code int[column][row]}; all helpers below follow that convention.
 */
public class array2DComplexMatrix {

    /**
     * Reads the dimensions and contents of matrices A and B from standard
     * input, then repeatedly asks for an operation until one is carried out.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        Scanner input = new Scanner(System.in);
        int columnA;
        int rowA;
        int columnB;
        int rowB;
        boolean valid = false;

        //------------------------------ getting the size of the matrices ---------------------------------//
        do {
            System.out.println("columns in matrix A");
            columnA = checkint(input, true, 0);

            System.out.println("rows in matrix A");
            rowA = checkint(input, true, 0);

            System.out.println("columns in matrix B");
            columnB = checkint(input, true, 0);

            System.out.println("rows in matrix B");
            rowB = checkint(input, true, 0);

            System.out.println("columnA: " + columnA + "\nrowA: " + rowA + "\ncolumnB: " + columnB + "\nrowB: " + rowB);

            // At least one operation must be possible, otherwise the
            // operation loop further down could never terminate.
            if (columnA == rowB || (columnA == columnB && rowA == rowB) || columnA == rowA || columnB == rowB) {
                valid = true;
            } else {
                System.out.println("matrices must be of same dimension, square, or column of matrix 1 must match row of matrix 2");
            }
        } while (!valid);

        //------------------------------ making + filling the matrices ---------------------------------//
        int[][] matrixA = new int[columnA][rowA];
        int[][] matrixB = new int[columnB][rowB];

        for (int x = 0; x < columnA; x++) {
            for (int y = 0; y < rowA; y++) {
                System.out.println("What number is in position " + (x + 1) + ", " + (y + 1) + " for matrix A?");
                matrixA[x][y] = checkint(input, false, 0);
            }
        }
        printMatrix("Matrix A", matrixA, rowA, columnA);

        for (int x = 0; x < columnB; x++) {
            for (int y = 0; y < rowB; y++) {
                System.out.println("What number is in position " + (x + 1) + ", " + (y + 1) + " for matrix B?");
                matrixB[x][y] = checkint(input, false, 0);
            }
        }
        printMatrix("Matrix B", matrixB, rowB, columnB);

        //------------------------------------------- what operation? -------------------------------------------//
        System.out.println("What operation?");
        valid = false;
        do {
            // checkint leaves the end-of-line behind, so blank lines are skipped here.
            String ans = readAnswer(input);

            if (ans.equalsIgnoreCase("addition") || ans.equalsIgnoreCase("+")) {
                if (columnA == columnB && rowA == rowB) {
                    addMatrix(matrixA, matrixB, columnA, rowA);
                    valid = true;
                } else {
                    System.out.println("matrices must be of same dimensions");
                }
            } else if (ans.equalsIgnoreCase("subtraction") || ans.equalsIgnoreCase("substraction")
                    || ans.equalsIgnoreCase("-")) {
                // The historical misspelling is still accepted for compatibility.
                if (columnA == columnB && rowA == rowB) {
                    subMatrix(matrixA, matrixB, columnA, rowA);
                    valid = true;
                } else {
                    System.out.println("matrices must be of same dimensions");
                }
            } else if (ans.equalsIgnoreCase("multiplication") || ans.equalsIgnoreCase("*") || ans.equalsIgnoreCase("x")) {
                System.out.println("Scalar?");
                ans = readAnswer(input);
                if (ans.equalsIgnoreCase("yes")) {
                    System.out.println("Matrix A or B?");
                    ans = readAnswer(input);
                    if (isMatrixChoice(ans, "A")) {
                        System.out.println("by what number?");
                        int num = checkint(input, false, 0);
                        scalarMultiplication(matrixA, columnA, rowA, num);
                        valid = true;
                    } else if (isMatrixChoice(ans, "B")) {
                        System.out.println("by what number?");
                        int num = checkint(input, false, 0);
                        scalarMultiplication(matrixB, columnB, rowB, num);
                        valid = true;
                    } else {
                        System.out.println("invalid");
                    }
                } else if (ans.equalsIgnoreCase("no")) {
                    if (columnA == rowB) {
                        multiplyMatrix(matrixA, matrixB, columnA, rowA, columnB);
                        valid = true;
                    } else {
                        System.out.println("columns in first matrix must match rows of second");
                    }
                } else {
                    System.out.println("invalid");
                }
            } else if (ans.equalsIgnoreCase("inverse") || ans.equalsIgnoreCase("^-1")) {
                System.out.println("matrix A or B?");
                ans = readAnswer(input);
                if (isMatrixChoice(ans, "A")) {
                    if (columnA == rowA) {
                        inverseMatrix(matrixA, columnA);
                        valid = true;
                    } else {
                        System.out.println("matrix must be square");
                    }
                } else if (isMatrixChoice(ans, "B")) {
                    if (columnB == rowB) {
                        inverseMatrix(matrixB, columnB);
                        valid = true;
                    } else {
                        System.out.println("matrix must be square");
                    }
                } else {
                    System.out.println("invalid");
                }
            } else {
                System.out.println("invalid");
            }
        } while (!valid);
    }

    /** Reads lines until a non-blank one is found and returns it trimmed. */
    private static String readAnswer(Scanner input) {
        String line = input.nextLine().trim();
        while (line.isEmpty()) {
            line = input.nextLine().trim();
        }
        return line;
    }

    /** Returns true when {@code ans} names the given matrix, e.g. "A", "matrix A" or "matrixA". */
    private static boolean isMatrixChoice(String ans, String letter) {
        return ans.equalsIgnoreCase(letter)
                || ans.equalsIgnoreCase("matrix " + letter)
                || ans.equalsIgnoreCase("matrix" + letter);
    }

    //------------------------------------------- method for checking for valid inputted integers -------------------------------------------//
    /**
     * Reads an integer from {@code input}, re-prompting on anything invalid.
     *
     * @param input      scanner to read from
     * @param isThereMin whether {@code min} is enforced
     * @param min        exclusive lower bound; values {@code <= min} are rejected when enforced
     * @return the first acceptable integer read
     */
    public static int checkint(Scanner input, boolean isThereMin, int min) {
        int target;
        do {
            while (!input.hasNextInt()) {
                System.out.println("invalid input. Try again.");
                input.nextLine();
            }
            target = input.nextInt();
            if (isThereMin && target <= min) {
                System.out.println("invalid input. Try again.");
            }
        } while (isThereMin && target <= min);
        return target;
    }

    //------------------------------------------- method for printing matrices -------------------------------------------//
    /**
     * Prints {@code label} followed by the matrix, one row per line.
     *
     * @param label  heading printed before the matrix
     * @param array  matrix stored as {@code [column][row]}
     * @param row    number of rows
     * @param column number of columns
     */
    public static void printMatrix(String label, int[][] array, int row, int column) {
        System.out.println(label);
        for (int x = 0; x < row; x++) {
            for (int y = 0; y < column; y++) {
                System.out.print(array[y][x] + " ");
            }
            System.out.print("\n");
        }
    }

    //------------------------------------------- method for adding matrices -------------------------------------------//
    /**
     * Prints the element-wise sum of two equally sized matrices.
     *
     * @param arrayA first operand, {@code [column][row]}
     * @param arrayB second operand, {@code [column][row]}
     * @param column number of columns of both operands
     * @param row    number of rows of both operands
     */
    public static void addMatrix(int[][] arrayA, int[][] arrayB, int column, int row) {
        int[][] arrayAB = new int[column][row];

        // First index is the column, second the row (matches the allocation above).
        for (int x = 0; x < column; x++) {
            for (int y = 0; y < row; y++) {
                arrayAB[x][y] = arrayA[x][y] + arrayB[x][y];
            }
        }
        printMatrix("Matrix A + Matrix B", arrayAB, row, column);
    }

    //------------------------------------------- method for subtracting matrices -------------------------------------------//
    /**
     * Prints the element-wise difference {@code A - B} of two equally sized matrices.
     *
     * @param arrayA minuend, {@code [column][row]}
     * @param arrayB subtrahend, {@code [column][row]}
     * @param column number of columns of both operands
     * @param row    number of rows of both operands
     */
    public static void subMatrix(int[][] arrayA, int[][] arrayB, int column, int row) {
        int[][] arrayAB = new int[column][row];

        for (int x = 0; x < column; x++) {
            for (int y = 0; y < row; y++) {
                arrayAB[x][y] = arrayA[x][y] - arrayB[x][y];
            }
        }
        printMatrix("Matrix A - Matrix B", arrayAB, row, column);
    }

    //------------------------------------------- method for multiplying matrices -------------------------------------------//
    /**
     * Prints the matrix product {@code A * B}.
     *
     * @param arrayA  left operand, {@code [columnA][rowA]}
     * @param arrayB  right operand, {@code [columnB][columnA]} (its row count must equal {@code columnA})
     * @param columnA number of columns of A
     * @param rowA    number of rows of A
     * @param columnB number of columns of B
     */
    public static void multiplyMatrix(int[][] arrayA, int[][] arrayB, int columnA, int rowA, int columnB) {
        int[][] arrayAB = new int[columnB][rowA];

        for (int x = 0; x < columnB; x++) {
            for (int y = 0; y < rowA; y++) {
                for (int z = 0; z < columnA; z++) {
                    arrayAB[x][y] += arrayA[z][y] * arrayB[x][z];
                }
            }
        }
        printMatrix("Matrix A * Matrix B", arrayAB, rowA, columnB);
    }

    //------------------------------------------- method for scalar multiplying -------------------------------------------//
    /**
     * Prints {@code array} with every element multiplied by {@code num}.
     *
     * @param array  matrix stored as {@code [column][row]}
     * @param column number of columns
     * @param row    number of rows
     * @param num    scalar factor
     */
    public static void scalarMultiplication(int[][] array, int column, int row, int num) {
        int[][] arrayFin = new int[column][row];

        for (int x = 0; x < column; x++) {
            for (int y = 0; y < row; y++) {
                arrayFin[x][y] = array[x][y] * num;
            }
        }
        printMatrix("Matrix", arrayFin, row, column);
    }

    //------------------------------------------- method for transposing matrices -------------------------------------------//
    /**
     * Transposes a matrix, prints the result and returns it.
     *
     * @param array  matrix stored as {@code [column][row]}
     * @param column number of columns of {@code array}
     * @param row    number of rows of {@code array}
     * @return the transpose, stored as {@code [row][column]}
     */
    public static int[][] transpose(int[][] array, int column, int row) {
        int[][] transposed = new int[row][column];

        for (int x = 0; x < column; x++) {
            for (int y = 0; y < row; y++) {
                transposed[y][x] = array[x][y];
            }
        }
        // The transpose has `column` rows and `row` columns.
        printMatrix("transposed", transposed, column, row);
        return transposed;
    }

    //------------------------------------------- method for finding inverse of matrices -------------------------------------------//
    /**
     * Prints the inverse of a square matrix, or a message when the
     * determinant is 0 and no inverse exists.
     *
     * @param array  square matrix
     * @param column number of columns (and rows)
     */
    public static void inverseMatrix(int[][] array, int column) {
        int det = determinantMatrix(array, column);
        // Test the determinant itself: 1.0 / 0 is Infinity, never 0.
        if (det == 0) {
            System.out.println("impossible. Determinant is 0");
            return;
        }

        double[][] inverse = new double[column][column];
        int[][] adjoint = adjointMatrix(array, column);

        for (int x = 0; x < column; x++) {
            for (int y = 0; y < column; y++) {
                inverse[x][y] = adjoint[x][y] / (double) det;
            }
        }

        System.out.println("inverse Matrix");
        for (int x = 0; x < column; x++) {
            for (int y = 0; y < column; y++) {
                System.out.print(inverse[y][x] + " ");
            }
            System.out.print("\n");
        }
    }

    /**
     * Computes the adjoint (transposed cofactor matrix) of a square matrix.
     * Prints the intermediate transpose and the final adjoint.
     *
     * @param array  square matrix
     * @param column number of columns (and rows)
     * @return the adjoint matrix
     */
    public static int[][] adjointMatrix(int[][] array, int column) {
        int[][] adjoint = new int[column][column];

        for (int x = 0; x < column; x++) {
            for (int y = 0; y < column; y++) {
                int[][] arraySml = findSmallerArray(array, column, x, y);
                adjoint[x][y] = determinantMatrix(arraySml, column - 1);
                if ((x + y) % 2 != 0) {
                    adjoint[x][y] *= -1;
                }
            }
        }
        adjoint = transpose(adjoint, column, column);

        printMatrix("adjoint", adjoint, column, column);
        return adjoint;
    }

    /**
     * Computes the determinant of a square matrix by cofactor expansion
     * along the first row.
     *
     * @param array  square matrix
     * @param column number of columns (and rows)
     * @return the determinant; 1 for the empty (0x0) matrix
     */
    public static int determinantMatrix(int[][] array, int column) {
        if (column == 0) {
            // Empty determinant; needed so the adjoint of a 1x1 matrix is [1].
            return 1;
        }
        if (column == 1) {
            return array[0][0];
        }
        if (column == 2) {
            return (array[0][0] * array[1][1]) - (array[1][0] * array[0][1]);
        }

        int det = 0;
        for (int x = 0; x < column; x++) {
            int[][] arraySml = findSmallerArray(array, column, x, 0);
            int term = array[x][0] * determinantMatrix(arraySml, column - 1);
            // Signs alternate +, -, +, ... across the expansion.
            if (x % 2 == 0) {
                det += term;
            } else {
                det -= term;
            }
        }
        return det;
    }

    /**
     * Returns a copy of {@code array} with one column and one row removed
     * (the minor used for cofactors).
     *
     * @param array        square matrix
     * @param column       number of columns (and rows) of {@code array}
     * @param targetColumn index of the column to drop
     * @param targetRow    index of the row to drop
     * @return the {@code (column - 1) x (column - 1)} sub-matrix
     */
    public static int[][] findSmallerArray(int[][] array, int column, int targetColumn, int targetRow) {
        int[][] arraySml = new int[column - 1][column - 1];
        int columnCnt = 0;

        for (int x = 0; x < column - 1; x++) {
            if (columnCnt == targetColumn) {
                columnCnt++;
            }
            int rowCnt = 0;
            for (int y = 0; y < column - 1; y++) {
                if (rowCnt == targetRow) {
                    rowCnt++;
                }
                arraySml[x][y] = array[columnCnt][rowCnt];
                rowCnt++;
            }
            columnCnt++;
        }
        return arraySml;
    }
}
Pong AI
A reinforcement-learning agent that learns to play Pong.
import pygame
import sys
from pygame.math import Vector2

pygame.init()

screen = pygame.display.set_mode((640, 480))
clock = pygame.time.Clock()
SCREEN_UPDATE = pygame.USEREVENT
pygame.time.set_timer(SCREEN_UPDATE, 150)


class pongAI:
    """Pong environment stepped one frame at a time by an external agent.

    One agent controls both paddles through a one-hot action
    ``[w, s, up, down]``: ``w``/``s`` move the left paddle up/down and
    ``up``/``down`` move the right paddle.
    """

    def __init__(self):
        self.reward = 0
        # Start from the same layout every later episode uses.
        self.reset()

    def play_step(self, action):
        """Advance the game by one frame.

        :param action: one-hot list ``[w, s, up, down]``
        :return: ``(reward, done, score)`` where ``reward`` is the reward
            earned during this step only, ``done`` tells whether the ball
            left the field, and ``score`` is the number of paddle hits of
            the episode (taken before the game is reset when ``done``).
        """
        # Rewards are per step; do not carry over the previous step's value.
        self.reward = 0

        for event in pygame.event.get():
            # if game is closed
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()

            # to update the visuals
            if event.type == SCREEN_UPDATE:
                self.update_ui()

        # to move paddles with action [w , s , up, down]
        move = list(action)
        if move == [1, 0, 0, 0]:
            self._move_paddle(self.paddle1, -5)
        if move == [0, 1, 0, 0]:
            self._move_paddle(self.paddle1, 5)
        if move == [0, 0, 1, 0]:
            self._move_paddle(self.paddle2, -5)
        if move == [0, 0, 0, 1]:
            self._move_paddle(self.paddle2, 5)

        # to move ball
        self.ball += self.ball_velocity
        if self.ball.y >= 475 or self.ball.y <= 0:
            self.ball_velocity.y *= -1

        # check if ball hit paddle
        if self.check_paddle_hit():
            self.ball_velocity.x *= -1
            self.score += 1
            self.reward = 10
            print(self.score)
            if self.score % 4 == 0 and self.score != 0:
                print(self.ball_velocity)
                # Speed up along the current direction; a plain "+= 1" would
                # slow down (or stop) a component that is negative.
                self.ball_velocity.x += self._direction(self.ball_velocity.x)
                self.ball_velocity.y += self._direction(self.ball_velocity.y)
                print(self.ball_velocity)

        # Read the score before a possible reset wipes it.
        score = self.score
        done = self.check_game_over()
        if done:
            self.reset()
            self.reward = -10

        pygame.display.flip()
        clock.tick(60)

        return self.reward, done, score

    @staticmethod
    def _direction(component):
        """Return -1 for a negative velocity component, otherwise +1."""
        return -1 if component < 0 else 1

    def _move_paddle(self, paddle, delta):
        """Shift ``paddle`` vertically by ``delta``, clamped to [0, 380].

        Sets a reward of 1 when the move brought the paddle's y closer to
        the ball's y.
        """
        old = paddle.y
        paddle.y = min(380, max(0, int(paddle.y + delta)))
        if abs(paddle.y - self.ball.y) < abs(old - self.ball.y):
            self.reward = 1

    def reset(self):
        """Put ball and paddles back to their start positions and clear the score."""
        self.ball = Vector2(320, 240)
        self.ball_velocity = Vector2(1, 1)
        self.paddle1 = Vector2(25, 190)
        self.paddle2 = Vector2(615, 190)
        self.score = 0
        self.last_paddle = None

    def check_game_over(self):
        """Return True when the ball has left the field horizontally."""
        return self.ball.x > 640 or self.ball.x < 0

    def check_paddle_hit(self):
        """Return True when the ball newly collides with a paddle.

        A paddle that was also the last one hit does not count again, so a
        ball overlapping a paddle for several frames registers a single hit.
        """
        # Create rectangles for the ball and paddles
        # NOTE(review): the hit boxes are 15 wide while update_ui draws the
        # paddles 10 wide - confirm whether the extra margin is intended.
        ball_rect = pygame.Rect(self.ball.x - 5, self.ball.y - 5, 10, 10)
        paddle1_rect = pygame.Rect(self.paddle1.x, self.paddle1.y, 15, 100)
        paddle2_rect = pygame.Rect(self.paddle2.x, self.paddle2.y, 15, 100)

        # Check for collision between ball and paddles
        if ball_rect.colliderect(paddle1_rect):
            if self.last_paddle == "paddle1_rect":
                return False
            self.last_paddle = "paddle1_rect"
            return True

        if ball_rect.colliderect(paddle2_rect):
            if self.last_paddle == "paddle2_rect":
                return False
            self.last_paddle = "paddle2_rect"
            return True

        return False

    def update_ui(self):
        """Redraw the background, the ball and both paddles."""
        screen.fill(pygame.Color("Black"))
        paddle1_rect = pygame.Rect(self.paddle1.x, self.paddle1.y, 10, 100)
        paddle2_rect = pygame.Rect(self.paddle2.x, self.paddle2.y, 10, 100)
        pygame.draw.circle(screen, pygame.Color("White"), [self.ball.x, self.ball.y], 5, 0)
        pygame.draw.rect(screen, pygame.Color("White"), paddle1_rect)
        pygame.draw.rect(screen, pygame.Color("White"), paddle2_rect)


pong = pongAI()
import torch
import random
import numpy as np
from collections import deque
import sys
from model import Linear_QNet, QTrainer
from main import pongAI
from helper import plot

MAX_MEMORY = 1000
BATCH_SIZE = 100
LR = 0.001


class Agent:
    """Q-learning agent that picks paddle moves for the pong game."""

    def __init__(self):
        self.n_games = 0
        self.epsilon = 0  # randomness
        self.gamma = 0.9  # discount rate
        # Replay buffer; the oldest entries drop out once MAX_MEMORY is reached.
        self.memory = deque(maxlen=MAX_MEMORY)
        # 8 state values in, 4 possible moves out.
        self.model = Linear_QNet(8, 256, 4)
        self.trainer = QTrainer(self.model, lr=LR, gamma=self.gamma)

    def get_state(self, game):
        """Return the game state as a float32 array of 8 values:
        ball position, ball velocity, paddle 1 position, paddle 2 position
        (x then y for each).
        """
        state = [
            # get ball position [x + y values]
            game.ball.x,
            game.ball.y,

            # get ball velocity [x + y values]
            game.ball_velocity.x,
            game.ball_velocity.y,

            # get paddle position 1 [x + y values]
            game.paddle1.x,
            game.paddle1.y,

            # get paddle position 2 [x + y values]
            game.paddle2.x,
            game.paddle2.y
        ]
        return np.array(state, dtype=np.float32)

    def remember(self, state, action, reward, next_state, done):
        """Store one transition in the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))  # popleft if MAX_MEMORY is reached

    def train_long_memory(self):
        """Train on up to BATCH_SIZE transitions sampled from the replay buffer."""
        if len(self.memory) > BATCH_SIZE:
            mini_sample = random.sample(self.memory, BATCH_SIZE)  # makes a batch using the memory and randoms
        else:
            mini_sample = self.memory

        states, actions, rewards, next_states, dones = zip(*mini_sample)
        self.trainer.train_step(states, actions, rewards, next_states, dones)

    def train_short_memory(self, state, action, reward, next_state, done):
        """Train on the single most recent transition."""
        self.trainer.train_step(state, action, reward, next_state, done)

    def get_action(self, state):
        """Return the next move as a one-hot list [w , s , up, down].

        A random move is chosen with a probability that shrinks as more
        games are played (never below 5 in 201); otherwise the move with the
        highest predicted Q value is taken.
        """
        self.epsilon = max(5, 90 - self.n_games)
        final_move = [0, 0, 0, 0]
        if random.randint(0, 200) < self.epsilon:
            move = random.randint(0, 3)
            final_move[move] = 1
        else:
            state0 = torch.tensor(state, dtype=torch.float)
            prediction = self.model(state0)
            move = torch.argmax(prediction).item()
            final_move[move] = 1

        return final_move


def train():
    """Run the training loop forever: act, learn from the step, and after
    each finished game train on replayed memory and update the plot."""
    plot_scores = []
    plot_mean_scores = []
    total_score = 0
    record = 0
    agent = Agent()
    game = pongAI()
    while True:
        # get old state
        state_old = agent.get_state(game)

        # get move
        final_move = agent.get_action(state_old)

        # perform move and get new state
        reward, done, score = game.play_step(final_move)
        # score is an int; concatenating it to a str would raise TypeError.
        print(f"score is {score}")
        state_new = agent.get_state(game)

        # train short memory
        agent.train_short_memory(state_old, final_move, reward, state_new, done)

        # remember
        agent.remember(state_old, final_move, reward, state_new, done)
        if done:
            # train long memory, plot result
            game.reset()
            agent.n_games += 1
            agent.train_long_memory()

            if score > record:
                record = score
                agent.model.save()

            print('Game', agent.n_games, 'Score', score, 'Record:', record)

            plot_scores.append(score)
            total_score += score
            mean_score = total_score / agent.n_games
            plot_mean_scores.append(mean_score)
            plot(plot_scores, plot_mean_scores)


while True:
    train()
import matplotlib.pyplot as plt
from IPython import display

plt.ion()


def plot(scores, mean_scores):
    """Redraw the training chart with per-game scores and their running mean.

    The value of the latest entry of each series is written next to its
    last point.
    """
    display.clear_output(wait=True)
    display.display(plt.gcf())
    plt.clf()

    plt.title('Training...')
    plt.xlabel('Number of Games')
    plt.ylabel('Score')

    series_to_draw = (scores, mean_scores)
    for series in series_to_draw:
        plt.plot(series)
    plt.ylim(ymin=0)

    # Label the most recent value of each curve at its end point.
    for series in series_to_draw:
        latest = series[-1]
        plt.text(len(series) - 1, latest, str(latest))

    plt.show(block=False)
    plt.pause(.1)
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import os

class Linear_QNet(nn.Module):
    """Two-layer fully connected network mapping a state to one Q value per action."""

    def __init__(self, input_size, hidden_size1, output_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size1)
        self.linear2 = nn.Linear(hidden_size1, output_size)

    def forward(self, x):
        """Return the Q values for state(s) ``x`` (ReLU hidden layer, linear output)."""
        x = F.relu(self.linear1(x))
        x = self.linear2(x)
        return x

    def save(self, file_name='model.pth'):
        """Write the model weights to ``./model/<file_name>``, creating the folder if needed."""
        model_folder_path = './model'
        os.makedirs(model_folder_path, exist_ok=True)

        file_name = os.path.join(model_folder_path, file_name)
        torch.save(self.state_dict(), file_name)


class QTrainer:
    """Performs Q-learning updates on a model with Adam and an MSE loss."""

    def __init__(self, model, lr, gamma):
        self.lr = lr
        self.gamma = gamma
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=self.lr)
        self.criterion = nn.MSELoss()

    def train_step(self, state, action, reward, next_state, done):
        """Run one optimisation step on a single transition or a batch.

        :param state: one state vector, or a sequence of them
        :param action: one-hot action, or a sequence of them
        :param reward: scalar reward, or a sequence of them
        :param next_state: state(s) reached after the action
        :param done: bool, or a sequence of bools, marking terminal transitions
        :return: the loss of this step as a Python float
        """
        state = torch.tensor(np.array(state), dtype=torch.float)
        next_state = torch.tensor(np.array(next_state), dtype=torch.float)
        action = torch.tensor(np.array(action), dtype=torch.long)
        reward = torch.tensor(np.array(reward), dtype=torch.float)
        # (n, x)

        if len(state.shape) == 1:
            # (1, x)
            state = torch.unsqueeze(state, 0)
            next_state = torch.unsqueeze(next_state, 0)
            action = torch.unsqueeze(action, 0)
            reward = torch.unsqueeze(reward, 0)
            done = (done,)

        # 1: predicted Q values with current state
        pred = self.model(state)

        # 2: Q_new = r + y * max(next_predicted Q value) -> only do this if not done
        # The target is a fixed regression goal: it is built without
        # gradient tracking so the loss only back-propagates through `pred`.
        with torch.no_grad():
            not_done = torch.tensor(~np.array(done, dtype=bool), dtype=torch.float)
            next_best = self.model(next_state).max(dim=1).values
            q_new = reward + self.gamma * next_best * not_done

        # Same as pred everywhere except at the action that was taken.
        target = pred.detach().clone()
        rows = torch.arange(state.shape[0])
        target[rows, action.argmax(dim=1)] = q_new

        self.optimizer.zero_grad()
        loss = self.criterion(pred, target)
        loss.backward()

        self.optimizer.step()
        return loss.item()