amélioration du système de sauvegarde de l'ia

This commit is contained in:
2024-08-20 17:28:41 +02:00
parent a8f0e7cdca
commit 4064bde716
9 changed files with 241 additions and 133 deletions

View File

@@ -103,8 +103,6 @@ Pour que le programme comprenne le fichier xml, il doit y avoir des balises spé
</Configuration> </Configuration>
``` ```
``
# JEU et GAMEPLAY # JEU et GAMEPLAY
Ce jeu est un 1vs1 snake tactique tour par tour avec une gestion de mur et de fruits (que l'on peut ajouter aléatoirement ou en le directement en le placant par x et y), nous pouvons se déplacer avec les touches **z q s d ou/et w a s d**, le jeu se termine quand l'un des 2 snake meurt soit en foncant dans un corps, soit par un mur. Ce jeu est un 1vs1 snake tactique tour par tour avec une gestion de mur et de fruits (que l'on peut ajouter aléatoirement ou en le directement en le placant par x et y), nous pouvons se déplacer avec les touches **z q s d ou/et w a s d**, le jeu se termine quand l'un des 2 snake meurt soit en foncant dans un corps, soit par un mur.
@@ -165,7 +163,7 @@ Ce calcul sera la valeur de toutes les actions que l'IA va enregistrer dans sa b
## Resultat : ## Resultat :
![IA](res/video/ia_solo_15min_apprentissage.gif) ![IA](res/ia_solo_15min_apprentissage.gif)
Dans cette vidéo, l'ia s'est entrainé pendant 15min tout seul et il a trouvé le meilleur chemin selon lui jusqu'à sa derniere erreur sur son apprentissage. Dans cette vidéo, l'ia s'est entrainé pendant 15min tout seul et il a trouvé le meilleur chemin selon lui jusqu'à sa derniere erreur sur son apprentissage.
si je l'apprennais encore un peu plus, il pourra rester le plus longtemps possible. si je l'apprennais encore un peu plus, il pourra rester le plus longtemps possible.

View File

Before

Width:  |  Height:  |  Size: 816 KiB

After

Width:  |  Height:  |  Size: 816 KiB

Binary file not shown.

Binary file not shown.

View File

@@ -2,6 +2,7 @@ import configuration.ConfigGame;
import game.Terminal; import game.Terminal;
import game.environnement.*; import game.environnement.*;
import personnage.*; import personnage.*;
import personnage.IAQLearning.QTable;
public class Main { public class Main {
/** /**
@@ -26,7 +27,6 @@ public class Main {
* Pour la QTable, il est préférable de créer une variable avec la * Pour la QTable, il est préférable de créer une variable avec la
* déclaration de la classe : * déclaration de la classe :
* QTable qtable = new QTable(); * QTable qtable = new QTable();
*
*/ */
public static void main(String[] args) { public static void main(String[] args) {
@@ -36,6 +36,8 @@ public class Main {
Map map = config.getMap(); Map map = config.getMap();
Personnage.n = config.getN(); Personnage.n = config.getN();
QTable.folderStorage = 1000;
if (args.length < 1) { new Terminal(map, personnages).run(); } // lancer en local if (args.length < 1) { new Terminal(map, personnages).run(); } // lancer en local
else if (args.length == 2) { new Terminal(map, personnages).run(args[0], args[1]); } // lancer en ligne else if (args.length == 2) { new Terminal(map, personnages).run(args[0], args[1]); } // lancer en ligne
else { System.err.println("WARNING: vous avez mis un mauvais nombre d'argument"); } // erreur else { System.err.println("WARNING: vous avez mis un mauvais nombre d'argument"); } // erreur

View File

@@ -94,7 +94,7 @@ public class ConfigGame {
return personnages; return personnages;
} }
private Personnage choosePersonnage(String name, HashMap<String, String> information) throws Error { private Personnage choosePersonnage(String name, HashMap<String, String> information) {
int[] coordinate = new int[] { int[] coordinate = new int[] {
Integer.parseInt(information.get("x")), Integer.parseInt(information.get("x")),
Integer.parseInt(information.get("y")), Integer.parseInt(information.get("y")),
@@ -105,7 +105,7 @@ public class ConfigGame {
case "robot": return new Robot(information.get("name"), coordinate); case "robot": return new Robot(information.get("name"), coordinate);
case "ia": { case "ia": {
String path = information.get("QTable"); String path = information.get("QTable");
return new IA(coordinate, (path.equals("")) ? new QTable("res/save/") : new QTable(path), name); return new IA(coordinate, new QTable(path.equals("") ? "res/save/" : path, name), name);
} }
default: { default: {

View File

@@ -1,8 +1,16 @@
package personnage.IAQLearning; package personnage.IAQLearning;
import java.io.*; import java.io.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import personnage.types.Mouvement; import personnage.types.Mouvement;
@@ -20,6 +28,9 @@ public class QTable {
* necessaire pour que le bot puisse faire des actions. * necessaire pour que le bot puisse faire des actions.
*/ */
private HashMap<Actions, Double> qValues; private HashMap<Actions, Double> qValues;
public static int folderStorage = 1000;
private ExecutorService executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
/** /**
* Constructeur de la classe QTabl cree le HashMap qValues. * Constructeur de la classe QTabl cree le HashMap qValues.
@@ -32,9 +43,13 @@ public class QTable {
* Constructeur de la classe QTable cree le HashMap qValues et mets dans la liste * Constructeur de la classe QTable cree le HashMap qValues et mets dans la liste
* les informations du fichier dans le path. * les informations du fichier dans le path.
*/ */
public QTable(String pathFile) { public QTable(String pathFile, String pathname) {
qValues = new HashMap<>(); qValues = new HashMap<>();
getValues(pathFile); try {
get(pathFile, pathname);
} catch (ClassNotFoundException | IOException e) {
e.printStackTrace();
}
} }
/** /**
@@ -62,7 +77,15 @@ public class QTable {
* Cette méthode sauvegarde les valeurs Q dans un fichier spécifié. * Cette méthode sauvegarde les valeurs Q dans un fichier spécifié.
* @param path le chemin du fichier où sauvegarder les données * @param path le chemin du fichier où sauvegarder les données
*/ */
public void save(String path) { public void saveChunk(HashMap<Actions, Double> hashmapSlide, String path) {
try(ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(path))) {
oos.writeObject(hashmapSlide);
} catch (IOException e) {
e.printStackTrace();
}
}
public void saveChunk(String path) {
try(ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(path))) { try(ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(path))) {
oos.writeObject(qValues); oos.writeObject(qValues);
} catch (IOException e) { } catch (IOException e) {
@@ -73,14 +96,89 @@ public class QTable {
/** /**
* Cette méthode charge les valeurs Q depuis un fichier spécifié. * Cette méthode charge les valeurs Q depuis un fichier spécifié.
* @param path le chemin du fichier à partir duquel charger les données * @param path le chemin du fichier à partir duquel charger les données
* @throws ClassNotFoundException
*/ */
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public void getValues(String path) { public void getChunk(String path) throws IOException, ClassNotFoundException {
try(ObjectInputStream ois = new ObjectInputStream(new FileInputStream(path))) { try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(path))) {
qValues = (HashMap<Actions, Double>) ois.readObject(); qValues = (HashMap<Actions, Double>) ois.readObject();
}
}
} catch (IOException | ClassNotFoundException e) { @SuppressWarnings("unchecked")
save(path); private HashMap<Actions, Double> getChunkSave(String path) throws IOException, ClassNotFoundException {
try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(path))) {
return (HashMap<Actions, Double>)ois.readObject();
}
}
public void save(String pathFolderName, String name) {
File file = new File(pathFolderName);
if (name == null) name = "null";
if (file.isFile()) {
saveChunk(pathFolderName);
} else {
List<Map.Entry<Actions, Double>> entryList = new ArrayList<>(qValues.entrySet());
int indexFile = 0;
for (int i = 0; i < entryList.size(); i += folderStorage) {
int end = Math.min(i + folderStorage, entryList.size());
List<Map.Entry<Actions, Double>> subList = entryList.subList(i, end);
HashMap<Actions, Double> subHashMap = new HashMap<>();
for(Map.Entry<Actions, Double> subValue : subList) {
subHashMap.put(subValue.getKey(), subValue.getValue());
}
String fileName = pathFolderName + File.separator + name + "_part" + (++indexFile) + ".ser";
executor.submit(() -> saveChunk(subHashMap, fileName));
}
executor.shutdown();
try {
executor.awaitTermination(1, TimeUnit.HOURS);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
public void get(String pathFolderName, String name) throws ClassNotFoundException, IOException {
File file = new File(pathFolderName);
if (file.isFile()) {
getChunk(pathFolderName);
} else if (!(file.exists() && file.isDirectory())) {
System.err.println("Erreur : le fichier " + pathFolderName + " n'existe pas.");
System.exit(-1);
} else {
// les hashmaps basique ne supporte pas bien le multithread en java >> https://www.geeksforgeeks.org/concurrenthashmap-in-java/
ConcurrentHashMap<Actions, Double> multithreadHashMap = new ConcurrentHashMap<>();
File[] listFiles = file.listFiles((dir, filename) -> filename.startsWith(name) && filename.endsWith(".ser"));
if (listFiles != null) {
for (File partFile : listFiles) {
executor.submit(() -> {
try {
// Charger chaque fichier dans un HashMap temporaire
HashMap<Actions, Double> tempMap = getChunkSave(partFile.getPath());
// Ajouter chaque entrée au ConcurrentHashMap
multithreadHashMap.putAll(tempMap);
} catch (ClassNotFoundException | IOException e) {
e.printStackTrace();
}
});
}
executor.shutdown();
try {
executor.awaitTermination(1, TimeUnit.HOURS);
} catch (InterruptedException e) {
e.printStackTrace();
}
qValues = new HashMap<>(multithreadHashMap);
}
} }
} }
@@ -88,7 +186,7 @@ public class QTable {
* cette méthode renvoie dans le terminal tout les elements du * cette méthode renvoie dans le terminal tout les elements du
* hashmap. * hashmap.
*/ */
public void printValues() { public void printHashMap() {
for (Map.Entry<Actions, Double> value : qValues.entrySet()) { for (Map.Entry<Actions, Double> value : qValues.entrySet()) {
System.out.println(value.getKey().toString() + " -> " + value.getValue()); System.out.println(value.getKey().toString() + " -> " + value.getValue());
} }

View File

@@ -1,154 +1,154 @@
package tests; // package tests;
import java.io.File; // import java.io.File;
import java.util.Arrays; // import java.util.Arrays;
import game.environnement.Map; // import game.environnement.Map;
import personnage.IA; // import personnage.IA;
import personnage.Personnage; // import personnage.Personnage;
import personnage.IAQLearning.QTable; // import personnage.IAQLearning.QTable;
import personnage.IAQLearning.State; // import personnage.IAQLearning.State;
import personnage.types.Mouvement; // import personnage.types.Mouvement;
public class IATest { // public class IATest {
private final static String path1 = "res" + File.separator + // private final static String path1 = "res" + File.separator +
"save" + File.separator + // "save" + File.separator +
"learn1.ser"; // "learn1.ser";
private final static String path2 = "res" + File.separator + // private final static String path2 = "res" + File.separator +
"save" + File.separator + // "save" + File.separator +
"learn2.ser"; // "learn2.ser";
public static void learnIA() { // public static void learnIA() {
double alpha = 0.1; // double alpha = 0.1;
double gamma = 0.9; // double gamma = 0.9;
double epsilon = 1.0; // double epsilon = 1.0;
double decay_rate = 0.995; // double decay_rate = 0.995;
double minEpsilon = 0.01; // double minEpsilon = 0.01;
int totalEpisodes = 200; // int totalEpisodes = 200;
Personnage.n = 4; // Personnage.n = 4;
for(int episode = 0; episode < totalEpisodes; episode++) { // for(int episode = 0; episode < totalEpisodes; episode++) {
QTable qTable = new QTable(); // QTable qTable = new QTable();
IA iaqLearning = new IA(new int[] {2, 2}, qTable, alpha, gamma, epsilon, null); // IA iaqLearning = new IA(new int[] {2, 2}, qTable, alpha, gamma, epsilon, null);
Map map = new Map(12, 22); // Map map = new Map(12, 22);
qTable.getValues(path1); // qTable.get(path1);
while (true) { // while (true) {
Map mapIA = new Map(map.getGrid()[0].length, map.getGrid().length); // Map mapIA = new Map(map.getGrid()[0].length, map.getGrid().length);
mapIA.replaceGrid(map.getGrid()); // mapIA.replaceGrid(map.getGrid());
map.placePersonnages(iaqLearning); // map.placePersonnages(iaqLearning);
State currentState = iaqLearning.getCurrentState(map.getGrid()); // State currentState = iaqLearning.getCurrentState(map.getGrid());
Mouvement mouvement = iaqLearning.bestMouvement(currentState); // Mouvement mouvement = iaqLearning.bestMouvement(currentState);
iaqLearning.moveSnake(mouvement); // iaqLearning.moveSnake(mouvement);
int[] coordinate = iaqLearning.getHeadCoordinate(); // int[] coordinate = iaqLearning.getHeadCoordinate();
if(map.isGameOver(coordinate) || iaqLearning.applyEffects(map.getEffect(coordinate))) { // if(map.isGameOver(coordinate) || iaqLearning.applyEffects(map.getEffect(coordinate))) {
iaqLearning.receiveReward(currentState, mouvement, -1.0, currentState); // iaqLearning.receiveReward(currentState, mouvement, -1.0, currentState);
break; // break;
} // }
mapIA.placePersonnages(iaqLearning); // mapIA.placePersonnages(iaqLearning);
State nextState = iaqLearning.getCurrentState(mapIA.getGrid()); // State nextState = iaqLearning.getCurrentState(mapIA.getGrid());
iaqLearning.receiveReward(currentState, mouvement, 0.1, nextState); // iaqLearning.receiveReward(currentState, mouvement, 0.1, nextState);
iaqLearning.increaseRound(); // iaqLearning.increaseRound();
mapIA.clearMap(); // mapIA.clearMap();
map.clearMap(); // map.clearMap();
} // }
qTable.save(path1); // qTable.save(path1);
epsilon = Math.max(minEpsilon, epsilon * decay_rate); // epsilon = Math.max(minEpsilon, epsilon * decay_rate);
System.out.println("Episode : " + episode + " | States : " + qTable.getqValues().size()); // System.out.println("Episode : " + episode + " | States : " + qTable.getqValues().size());
} // }
} // }
public static void learnIAvsIA() { // public static void learnIAvsIA() {
double alpha = 0.9; // double alpha = 0.9;
double gamma = 0.9; // double gamma = 0.9;
double epsilon = 0.1; // double epsilon = 0.1;
int maxEpisode = 1000000; // int maxEpisode = 1000000;
Personnage.n = 4; // Personnage.n = 4;
QTable qTable1 = new QTable(); // QTable qTable1 = new QTable();
qTable1.getValues(path1); // qTable1.get(path1);
QTable qTable2 = new QTable(); // QTable qTable2 = new QTable();
qTable2.getValues(path2); // qTable2.get(path2);
for (int episode = 0; episode < maxEpisode; episode++) { // for (int episode = 0; episode < maxEpisode; episode++) {
Map map = new Map(12, 22); // Map map = new Map(12, 22);
IA[] iaqLearnings = new IA[] { // IA[] iaqLearnings = new IA[] {
new IA(new int[] {2, 2}, qTable1, alpha, gamma, epsilon, null), // new IA(new int[] {2, 2}, qTable1, alpha, gamma, epsilon, null),
new IA(new int[] {9, 19}, qTable2, alpha, gamma, epsilon, null), // new IA(new int[] {9, 19}, qTable2, alpha, gamma, epsilon, null),
}; // };
boolean isGameOver = false; // boolean isGameOver = false;
while(true) { // while(true) {
for (int personnages = 0; personnages < iaqLearnings.length; personnages++) { // for (int personnages = 0; personnages < iaqLearnings.length; personnages++) {
IA iaqLearning = iaqLearnings[personnages]; // IA iaqLearning = iaqLearnings[personnages];
Map mapIA = new Map(map.getGrid()[0].length, map.getGrid().length); // Map mapIA = new Map(map.getGrid()[0].length, map.getGrid().length);
for (IA value : iaqLearnings) { // for (IA value : iaqLearnings) {
map.placePersonnages(value); // map.placePersonnages(value);
} // }
State currentState = iaqLearning.getCurrentState(map.getGrid()); // State currentState = iaqLearning.getCurrentState(map.getGrid());
Mouvement mouvement = iaqLearning.bestMouvement(currentState); // Mouvement mouvement = iaqLearning.bestMouvement(currentState);
iaqLearning.moveSnake(mouvement); // iaqLearning.moveSnake(mouvement);
int[] coordinate = iaqLearning.getHeadCoordinate(); // int[] coordinate = iaqLearning.getHeadCoordinate();
if (map.isGameOver(coordinate) || iaqLearning.applyEffects(map.getEffect(coordinate))) { // if (map.isGameOver(coordinate) || iaqLearning.applyEffects(map.getEffect(coordinate))) {
iaqLearning.receiveReward(currentState, mouvement, -1000, currentState); // iaqLearning.receiveReward(currentState, mouvement, -1000, currentState);
isGameOver = true; // isGameOver = true;
break; // break;
} // }
int value = (personnages + 1) % 2; // int value = (personnages + 1) % 2;
for (int[] snakeCoordinate : iaqLearnings[value].getCoordinate()) { // for (int[] snakeCoordinate : iaqLearnings[value].getCoordinate()) {
if (Arrays.equals(coordinate, snakeCoordinate)) { // if (Arrays.equals(coordinate, snakeCoordinate)) {
iaqLearnings[value].receiveReward(currentState, mouvement, 1000, currentState); // iaqLearnings[value].receiveReward(currentState, mouvement, 1000, currentState);
iaqLearning.receiveReward(currentState, mouvement, -500, currentState); // iaqLearning.receiveReward(currentState, mouvement, -500, currentState);
isGameOver = true; // isGameOver = true;
break; // break;
} // }
} // }
mapIA.placePersonnages(iaqLearning); // mapIA.placePersonnages(iaqLearning);
State nextState = iaqLearning.getCurrentState(mapIA.getGrid()); // State nextState = iaqLearning.getCurrentState(mapIA.getGrid());
iaqLearning.receiveReward(currentState, mouvement, -0.01, nextState); // iaqLearning.receiveReward(currentState, mouvement, -0.01, nextState);
iaqLearning.increaseRound(); // iaqLearning.increaseRound();
mapIA.clearMap(); // mapIA.clearMap();
map.clearMap(); // map.clearMap();
System.out.println("States 1: " + qTable1.getqValues().size() + " States 2: " + qTable2.getqValues().size()); // System.out.println("States 1: " + qTable1.getqValues().size() + " States 2: " + qTable2.getqValues().size());
} // }
if(isGameOver) break; // if(isGameOver) break;
} // }
System.out.println(" States 1: " + qTable1.getqValues().size() + " States 2: " + qTable2.getqValues().size() + "Episode: " + episode); // System.out.println(" States 1: " + qTable1.getqValues().size() + " States 2: " + qTable2.getqValues().size() + "Episode: " + episode);
} // }
qTable1.save(path1); // qTable1.save(path1);
} // }
} // }

View File

@@ -1,6 +1,7 @@
package tests; package tests;
import java.io.File; import java.io.File;
import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@@ -11,8 +12,7 @@ import personnage.types.Mouvement;
public class QTableTest { public class QTableTest {
private final static String path = "res" + File.separator + private final static String path = "res" + File.separator +
"save" + File.separator + "save" + File.separator;
"test.ser";
public static void searchValue() { public static void searchValue() {
QTable qTable = new QTable(); QTable qTable = new QTable();
@@ -21,24 +21,30 @@ public class QTableTest {
qTable.setQValue(state, mouvement, 10.2); qTable.setQValue(state, mouvement, 10.2);
qTable.printValues(); qTable.printHashMap();
System.out.println(qTable.getQValue(state, mouvement)); // Devrait retourner 10.2 System.out.println(qTable.getQValue(state, mouvement)); // Devrait retourner 10.2
} }
public static void writeValueFile() { public static void writeValueFile() {
QTable.folderStorage = 1;
QTable qTable = new QTable(); QTable qTable = new QTable();
State state = new State(new Grid[3][3], new ArrayList<>()); State state = new State(new Grid[3][3], new ArrayList<>());
qTable.setQValue(state, Mouvement.BAS, 10.3); qTable.setQValue(state, Mouvement.BAS, 10.3);
qTable.save(path); // Devrait sauvegarder dans test.ser le state avec la valeur 10.3
qTable.setQValue(new State(new Grid[3][3], new ArrayList<>()), Mouvement.HAUT, 12.3);
qTable.save(path, "name");
} }
public static void searchValueFile() { public static void searchValueFile() {
QTable qTable = new QTable(); QTable qTable = new QTable();
qTable.getValues(path); // qTable.get(path);
State state = new State(new Grid[3][3], new ArrayList<>()); State state = new State(new Grid[3][3], new ArrayList<>());
qTable.printValues(); qTable.printHashMap();
System.out.println(qTable.getQValue(state, Mouvement.BAS)); // Devrait retourner 10.3 System.out.println(qTable.getQValue(state, Mouvement.BAS)); // Devrait retourner 10.3
} }
@@ -55,9 +61,13 @@ public class QTableTest {
QTable qTableReceived = new QTable(); QTable qTableReceived = new QTable();
qTableSend.setQValue(state, mouvement, 102.0); qTableSend.setQValue(state, mouvement, 102.0);
qTableSend.save(path); qTableSend.save(path, "fromage");
qTableReceived.getValues(path); try {qTableReceived.get(path, "fromage");} catch(ClassNotFoundException | IOException e) {e.printStackTrace();}
System.out.println(qTableReceived.getQValue(state, mouvement)); System.out.println(qTableReceived.getQValue(state, mouvement));
} }
public static void main(String[] args) {
getRealInformation();
}
} }