diff --git a/README.md b/README.md index 1becaaf..5fb0326 100644 --- a/README.md +++ b/README.md @@ -103,8 +103,6 @@ Pour que le programme comprenne le fichier xml, il doit y avoir des balises spé ``` -`` - # JEU et GAMEPLAY Ce jeu est un 1vs1 snake tactique tour par tour avec une gestion de mur et de fruits (que l'on peut ajouter aléatoirement ou en le directement en le placant par x et y), nous pouvons se déplacer avec les touches **z q s d ou/et w a s d**, le jeu se termine quand l'un des 2 snake meurt soit en foncant dans un corps, soit par un mur. @@ -165,7 +163,7 @@ Ce calcul sera la valeur de toutes les actions que l'IA va enregistrer dans sa b ## Resultat : -![IA](res/video/ia_solo_15min_apprentissage.gif) +![IA](res/ia_solo_15min_apprentissage.gif) Dans cette vidéo, l'ia s'est entrainé pendant 15min tout seul et il a trouvé le meilleur chemin selon lui jusqu'à sa derniere erreur sur son apprentissage. si je l'apprennais encore un peu plus, il pourra rester le plus longtemps possible. diff --git a/res/video/ia_solo_15min_apprentissage.gif b/res/ia_solo_15min_apprentissage.gif similarity index 100% rename from res/video/ia_solo_15min_apprentissage.gif rename to res/ia_solo_15min_apprentissage.gif diff --git a/res/save/learn1.ser b/res/save/learn1.ser deleted file mode 100644 index c597512..0000000 Binary files a/res/save/learn1.ser and /dev/null differ diff --git a/res/save/learn2.ser b/res/save/learn2.ser deleted file mode 100644 index c597512..0000000 Binary files a/res/save/learn2.ser and /dev/null differ diff --git a/src/Main.java b/src/Main.java index e2a1628..46dd453 100644 --- a/src/Main.java +++ b/src/Main.java @@ -2,6 +2,7 @@ import configuration.ConfigGame; import game.Terminal; import game.environnement.*; import personnage.*; +import personnage.IAQLearning.QTable; public class Main { /** @@ -26,7 +27,6 @@ public class Main { * Pour la QTable, il est préférable de créer une variable avec la * déclaration de la classe : * QTable qtable = new QTable(); - * */ public static void main(String[] args) { @@ -36,6 +36,8 @@ public class Main { Map map = config.getMap(); Personnage.n = config.getN(); + QTable.folderStorage = 1000; + if (args.length < 1) { new Terminal(map, personnages).run(); } // lancer en local else if (args.length == 2) { new Terminal(map, personnages).run(args[0], args[1]); } // lancer en ligne else { System.err.println("WARNING: vous avez mis un mauvais nombre d'argument"); } // erreur diff --git a/src/configuration/ConfigGame.java b/src/configuration/ConfigGame.java index 71750ad..aff67fc 100644 --- a/src/configuration/ConfigGame.java +++ b/src/configuration/ConfigGame.java @@ -94,7 +94,7 @@ public class ConfigGame { return personnages; } - private Personnage choosePersonnage(String name, HashMap information) throws Error { + private Personnage choosePersonnage(String name, HashMap information) { int[] coordinate = new int[] { Integer.parseInt(information.get("x")), Integer.parseInt(information.get("y")), @@ -105,7 +105,7 @@ public class ConfigGame { case "robot": return new Robot(information.get("name"), coordinate); case "ia": { String path = information.get("QTable"); - return new IA(coordinate, (path.equals("")) ? new QTable("res/save/") : new QTable(path), name); + return new IA(coordinate, new QTable(path.equals("") ? "res/save/" : path, name), name); } default: { diff --git a/src/personnage/IAQLearning/QTable.java b/src/personnage/IAQLearning/QTable.java index d024a4f..6572331 100644 --- a/src/personnage/IAQLearning/QTable.java +++ b/src/personnage/IAQLearning/QTable.java @@ -1,8 +1,16 @@ package personnage.IAQLearning; import java.io.*; +import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.TreeMap; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import personnage.types.Mouvement; @@ -20,6 +28,9 @@ public class QTable { * necessaire pour que le bot puisse faire des actions. */ private HashMap qValues; + public static int folderStorage = 1000; + + private ExecutorService executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors()); /** * Constructeur de la classe QTabl cree le HashMap qValues. @@ -32,9 +43,13 @@ public class QTable { * Constructeur de la classe QTable cree le HashMap qValues et mets dans la liste * les informations du fichier dans le path. */ - public QTable(String pathFile) { + public QTable(String pathFile, String pathname) { qValues = new HashMap<>(); - getValues(pathFile); + try { + get(pathFile, pathname); + } catch (ClassNotFoundException | IOException e) { + e.printStackTrace(); + } } /** @@ -62,7 +77,15 @@ public class QTable { * Cette méthode sauvegarde les valeurs Q dans un fichier spécifié. * @param path le chemin du fichier où sauvegarder les données */ - public void save(String path) { + public void saveChunk(HashMap hashmapSlide, String path) { + try(ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(path))) { + oos.writeObject(hashmapSlide); + } catch (IOException e) { + e.printStackTrace(); + } + } + + public void saveChunk(String path) { try(ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(path))) { oos.writeObject(qValues); } catch (IOException e) { @@ -73,22 +96,97 @@ public class QTable { /** * Cette méthode charge les valeurs Q depuis un fichier spécifié. * @param path le chemin du fichier à partir duquel charger les données + * @throws ClassNotFoundException */ @SuppressWarnings("unchecked") - public void getValues(String path) { - try(ObjectInputStream ois = new ObjectInputStream(new FileInputStream(path))) { + public void getChunk(String path) throws IOException, ClassNotFoundException { + try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(path))) { qValues = (HashMap) ois.readObject(); + } + } - } catch (IOException | ClassNotFoundException e) { - save(path); - } + @SuppressWarnings("unchecked") + private HashMap getChunkSave(String path) throws IOException, ClassNotFoundException { + try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(path))) { + return (HashMap)ois.readObject(); + } + } + + public void save(String pathFolderName, String name) { + File file = new File(pathFolderName); + if (name == null) name = "null"; + + if (file.isFile()) { + saveChunk(pathFolderName); + } else { + List> entryList = new ArrayList<>(qValues.entrySet()); + int indexFile = 0; + + for (int i = 0; i < entryList.size(); i += folderStorage) { + int end = Math.min(i + folderStorage, entryList.size()); + List> subList = entryList.subList(i, end); + + HashMap subHashMap = new HashMap<>(); + for(Map.Entry subValue : subList) { + subHashMap.put(subValue.getKey(), subValue.getValue()); + } + + String fileName = pathFolderName + File.separator + name + "_part" + (++indexFile) + ".ser"; + + executor.submit(() -> saveChunk(subHashMap, fileName)); + } + + executor.shutdown(); + try { + executor.awaitTermination(1, TimeUnit.HOURS); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + } + + public void get(String pathFolderName, String name) throws ClassNotFoundException, IOException { + File file = new File(pathFolderName); + + if (file.isFile()) { + getChunk(pathFolderName); + } else if (!(file.exists() && file.isDirectory())) { + System.err.println("Erreur : le fichier " + pathFolderName + " n'existe pas."); + System.exit(-1); + } else { + // les hashmaps basique ne supporte pas bien le multithread en java >> https://www.geeksforgeeks.org/concurrenthashmap-in-java/ + ConcurrentHashMap multithreadHashMap = new ConcurrentHashMap<>(); + File[] listFiles = file.listFiles((dir, filename) -> filename.startsWith(name) && filename.endsWith(".ser")); + if (listFiles != null) { + for (File partFile : listFiles) { + executor.submit(() -> { + try { + // Charger chaque fichier dans un HashMap temporaire + HashMap tempMap = getChunkSave(partFile.getPath()); + // Ajouter chaque entrée au ConcurrentHashMap + multithreadHashMap.putAll(tempMap); + } catch (ClassNotFoundException | IOException e) { + e.printStackTrace(); + } + }); + } + executor.shutdown(); + try { + executor.awaitTermination(1, TimeUnit.HOURS); + } catch (InterruptedException e) { + e.printStackTrace(); + } + + qValues = new HashMap<>(multithreadHashMap); + } + } } /** * cette méthode renvoie dans le terminal tout les elements du * hashmap. */ - public void printValues() { + public void printHashMap() { for (Map.Entry value : qValues.entrySet()) { System.out.println(value.getKey().toString() + " -> " + value.getValue()); } diff --git a/src/tests/IATest.java b/src/tests/IATest.java index f490462..14b3aae 100644 --- a/src/tests/IATest.java +++ b/src/tests/IATest.java @@ -1,154 +1,154 @@ -package tests; +// package tests; -import java.io.File; -import java.util.Arrays; +// import java.io.File; +// import java.util.Arrays; -import game.environnement.Map; -import personnage.IA; -import personnage.Personnage; -import personnage.IAQLearning.QTable; -import personnage.IAQLearning.State; -import personnage.types.Mouvement; +// import game.environnement.Map; +// import personnage.IA; +// import personnage.Personnage; +// import personnage.IAQLearning.QTable; +// import personnage.IAQLearning.State; +// import personnage.types.Mouvement; -public class IATest { - private final static String path1 = "res" + File.separator + - "save" + File.separator + - "learn1.ser"; +// public class IATest { +// private final static String path1 = "res" + File.separator + +// "save" + File.separator + +// "learn1.ser"; - private final static String path2 = "res" + File.separator + - "save" + File.separator + - "learn2.ser"; +// private final static String path2 = "res" + File.separator + +// "save" + File.separator + +// "learn2.ser"; - public static void learnIA() { - double alpha = 0.1; - double gamma = 0.9; - double epsilon = 1.0; - double decay_rate = 0.995; - double minEpsilon = 0.01; +// public static void learnIA() { +// double alpha = 0.1; +// double gamma = 0.9; +// double epsilon = 1.0; +// double decay_rate = 0.995; +// double minEpsilon = 0.01; - int totalEpisodes = 200; +// int totalEpisodes = 200; - Personnage.n = 4; +// Personnage.n = 4; - for(int episode = 0; episode < totalEpisodes; episode++) { - QTable qTable = new QTable(); - IA iaqLearning = new IA(new int[] {2, 2}, qTable, alpha, gamma, epsilon, null); - Map map = new Map(12, 22); +// for(int episode = 0; episode < totalEpisodes; episode++) { +// QTable qTable = new QTable(); +// IA iaqLearning = new IA(new int[] {2, 2}, qTable, alpha, gamma, epsilon, null); +// Map map = new Map(12, 22); - qTable.getValues(path1); +// qTable.get(path1); - while (true) { - Map mapIA = new Map(map.getGrid()[0].length, map.getGrid().length); - mapIA.replaceGrid(map.getGrid()); +// while (true) { +// Map mapIA = new Map(map.getGrid()[0].length, map.getGrid().length); +// mapIA.replaceGrid(map.getGrid()); - map.placePersonnages(iaqLearning); +// map.placePersonnages(iaqLearning); - State currentState = iaqLearning.getCurrentState(map.getGrid()); - Mouvement mouvement = iaqLearning.bestMouvement(currentState); +// State currentState = iaqLearning.getCurrentState(map.getGrid()); +// Mouvement mouvement = iaqLearning.bestMouvement(currentState); - iaqLearning.moveSnake(mouvement); +// iaqLearning.moveSnake(mouvement); - int[] coordinate = iaqLearning.getHeadCoordinate(); +// int[] coordinate = iaqLearning.getHeadCoordinate(); - if(map.isGameOver(coordinate) || iaqLearning.applyEffects(map.getEffect(coordinate))) { - iaqLearning.receiveReward(currentState, mouvement, -1.0, currentState); - break; - } +// if(map.isGameOver(coordinate) || iaqLearning.applyEffects(map.getEffect(coordinate))) { +// iaqLearning.receiveReward(currentState, mouvement, -1.0, currentState); +// break; +// } - mapIA.placePersonnages(iaqLearning); +// mapIA.placePersonnages(iaqLearning); - State nextState = iaqLearning.getCurrentState(mapIA.getGrid()); +// State nextState = iaqLearning.getCurrentState(mapIA.getGrid()); - iaqLearning.receiveReward(currentState, mouvement, 0.1, nextState); - iaqLearning.increaseRound(); +// iaqLearning.receiveReward(currentState, mouvement, 0.1, nextState); +// iaqLearning.increaseRound(); - mapIA.clearMap(); - map.clearMap(); - } +// mapIA.clearMap(); +// map.clearMap(); +// } - qTable.save(path1); +// qTable.save(path1); - epsilon = Math.max(minEpsilon, epsilon * decay_rate); - System.out.println("Episode : " + episode + " | States : " + qTable.getqValues().size()); - } - } +// epsilon = Math.max(minEpsilon, epsilon * decay_rate); +// System.out.println("Episode : " + episode + " | States : " + qTable.getqValues().size()); +// } +// } - public static void learnIAvsIA() { - double alpha = 0.9; - double gamma = 0.9; - double epsilon = 0.1; +// public static void learnIAvsIA() { +// double alpha = 0.9; +// double gamma = 0.9; +// double epsilon = 0.1; - int maxEpisode = 1000000; +// int maxEpisode = 1000000; - Personnage.n = 4; +// Personnage.n = 4; - QTable qTable1 = new QTable(); - qTable1.getValues(path1); +// QTable qTable1 = new QTable(); +// qTable1.get(path1); - QTable qTable2 = new QTable(); - qTable2.getValues(path2); +// QTable qTable2 = new QTable(); +// qTable2.get(path2); - for (int episode = 0; episode < maxEpisode; episode++) { - Map map = new Map(12, 22); +// for (int episode = 0; episode < maxEpisode; episode++) { +// Map map = new Map(12, 22); - IA[] iaqLearnings = new IA[] { - new IA(new int[] {2, 2}, qTable1, alpha, gamma, epsilon, null), - new IA(new int[] {9, 19}, qTable2, alpha, gamma, epsilon, null), - }; +// IA[] iaqLearnings = new IA[] { +// new IA(new int[] {2, 2}, qTable1, alpha, gamma, epsilon, null), +// new IA(new int[] {9, 19}, qTable2, alpha, gamma, epsilon, null), +// }; - boolean isGameOver = false; +// boolean isGameOver = false; - while(true) { - for (int personnages = 0; personnages < iaqLearnings.length; personnages++) { - IA iaqLearning = iaqLearnings[personnages]; - Map mapIA = new Map(map.getGrid()[0].length, map.getGrid().length); +// while(true) { +// for (int personnages = 0; personnages < iaqLearnings.length; personnages++) { +// IA iaqLearning = iaqLearnings[personnages]; +// Map mapIA = new Map(map.getGrid()[0].length, map.getGrid().length); - for (IA value : iaqLearnings) { - map.placePersonnages(value); - } +// for (IA value : iaqLearnings) { +// map.placePersonnages(value); +// } - State currentState = iaqLearning.getCurrentState(map.getGrid()); - Mouvement mouvement = iaqLearning.bestMouvement(currentState); +// State currentState = iaqLearning.getCurrentState(map.getGrid()); +// Mouvement mouvement = iaqLearning.bestMouvement(currentState); - iaqLearning.moveSnake(mouvement); +// iaqLearning.moveSnake(mouvement); - int[] coordinate = iaqLearning.getHeadCoordinate(); +// int[] coordinate = iaqLearning.getHeadCoordinate(); - if (map.isGameOver(coordinate) || iaqLearning.applyEffects(map.getEffect(coordinate))) { - iaqLearning.receiveReward(currentState, mouvement, -1000, currentState); - isGameOver = true; - break; - } +// if (map.isGameOver(coordinate) || iaqLearning.applyEffects(map.getEffect(coordinate))) { +// iaqLearning.receiveReward(currentState, mouvement, -1000, currentState); +// isGameOver = true; +// break; +// } - int value = (personnages + 1) % 2; +// int value = (personnages + 1) % 2; - for (int[] snakeCoordinate : iaqLearnings[value].getCoordinate()) { - if (Arrays.equals(coordinate, snakeCoordinate)) { - iaqLearnings[value].receiveReward(currentState, mouvement, 1000, currentState); - iaqLearning.receiveReward(currentState, mouvement, -500, currentState); +// for (int[] snakeCoordinate : iaqLearnings[value].getCoordinate()) { +// if (Arrays.equals(coordinate, snakeCoordinate)) { +// iaqLearnings[value].receiveReward(currentState, mouvement, 1000, currentState); +// iaqLearning.receiveReward(currentState, mouvement, -500, currentState); - isGameOver = true; - break; - } - } +// isGameOver = true; +// break; +// } +// } - mapIA.placePersonnages(iaqLearning); +// mapIA.placePersonnages(iaqLearning); - State nextState = iaqLearning.getCurrentState(mapIA.getGrid()); - iaqLearning.receiveReward(currentState, mouvement, -0.01, nextState); +// State nextState = iaqLearning.getCurrentState(mapIA.getGrid()); +// iaqLearning.receiveReward(currentState, mouvement, -0.01, nextState); - iaqLearning.increaseRound(); +// iaqLearning.increaseRound(); - mapIA.clearMap(); - map.clearMap(); +// mapIA.clearMap(); +// map.clearMap(); - System.out.println("States 1: " + qTable1.getqValues().size() + " States 2: " + qTable2.getqValues().size()); - } +// System.out.println("States 1: " + qTable1.getqValues().size() + " States 2: " + qTable2.getqValues().size()); +// } - if(isGameOver) break; - } - System.out.println(" States 1: " + qTable1.getqValues().size() + " States 2: " + qTable2.getqValues().size() + "Episode: " + episode); - } - qTable1.save(path1); - } -} +// if(isGameOver) break; +// } +// System.out.println(" States 1: " + qTable1.getqValues().size() + " States 2: " + qTable2.getqValues().size() + "Episode: " + episode); +// } +// qTable1.save(path1); +// } +// } diff --git a/src/tests/QTableTest.java b/src/tests/QTableTest.java index f6eb48d..f6cb54a 100644 --- a/src/tests/QTableTest.java +++ b/src/tests/QTableTest.java @@ -1,6 +1,7 @@ package tests; import java.io.File; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -11,8 +12,7 @@ import personnage.types.Mouvement; public class QTableTest { private final static String path = "res" + File.separator + - "save" + File.separator + - "test.ser"; + "save" + File.separator; public static void searchValue() { QTable qTable = new QTable(); @@ -21,24 +21,30 @@ public class QTableTest { qTable.setQValue(state, mouvement, 10.2); - qTable.printValues(); + qTable.printHashMap(); System.out.println(qTable.getQValue(state, mouvement)); // Devrait retourner 10.2 } public static void writeValueFile() { + QTable.folderStorage = 1; QTable qTable = new QTable(); State state = new State(new Grid[3][3], new ArrayList<>()); + qTable.setQValue(state, Mouvement.BAS, 10.3); - qTable.save(path); // Devrait sauvegarder dans test.ser le state avec la valeur 10.3 + + qTable.setQValue(new State(new Grid[3][3], new ArrayList<>()), Mouvement.HAUT, 12.3); + + + qTable.save(path, "name"); } public static void searchValueFile() { QTable qTable = new QTable(); - qTable.getValues(path); + // qTable.get(path); State state = new State(new Grid[3][3], new ArrayList<>()); - qTable.printValues(); + qTable.printHashMap(); System.out.println(qTable.getQValue(state, Mouvement.BAS)); // Devrait retourner 10.3 } @@ -55,9 +61,13 @@ public class QTableTest { QTable qTableReceived = new QTable(); qTableSend.setQValue(state, mouvement, 102.0); - qTableSend.save(path); + qTableSend.save(path, "fromage"); - qTableReceived.getValues(path); + try {qTableReceived.get(path, "fromage");} catch(ClassNotFoundException | IOException e) {e.printStackTrace();} System.out.println(qTableReceived.getQValue(state, mouvement)); } + + public static void main(String[] args) { + getRealInformation(); + } }