Source code for discrete_optimization.generic_tools.result_storage.resultcomparator

#  Copyright (c) 2022 AIRBUS and its affiliates.
#  This source code is licensed under the MIT license found in the
#  LICENSE file in the root directory of this source tree.

import logging
import math
from typing import Dict, List, Optional, Tuple, cast

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib.axes import Axes
from matplotlib.figure import Figure

from discrete_optimization.generic_tools.do_problem import (
    Problem,
    Solution,
    TupleFitness,
)
from discrete_optimization.generic_tools.plot_utils import get_cmap_with_nb_colors
from discrete_optimization.generic_tools.result_storage.result_storage import (
    ParetoFront,
    ResultStorage,
    fitness_class,
    plot_pareto_2d,
    result_storage_to_pareto_front,
)

logger = logging.getLogger(__name__)


class ResultComparator:
    """Compare several ``ResultStorage`` objects, numerically and graphically.

    If ``test_problems`` is None, the fitnesses already stored in each
    ``ResultStorage`` are used. Otherwise the best solution of every storage is
    re-evaluated on each test problem as soon as the comparator is built.
    """

    def __init__(
        self,
        list_result_storage: List[ResultStorage],
        result_storage_names: List[str],
        objectives_str: List[str],
        objective_weights: List[int],
        test_problems: Optional[List[Problem]] = None,
    ):
        """Store the inputs and, when test problems are given, re-evaluate.

        Args:
            list_result_storage: result storages to compare.
            result_storage_names: display name of each storage, in the same
                order as ``list_result_storage``.
            objectives_str: objective names; their positions are used as
                indexes into the ``vector_fitness`` of stored fitnesses.
            objective_weights: stored as given (not read by this class).
            test_problems: optional problems on which the best solution of
                each storage is re-evaluated.
        """
        self.list_result_storage = list_result_storage
        self.result_storage_names = result_storage_names
        self.objectives_str = objectives_str
        self.objective_weights = objective_weights
        self.test_problems = test_problems
        # storage index -> objective name -> one value per test problem
        self.reevaluated_results: Dict[int, Dict[str, List[float]]] = {}
        if self.test_problems is not None:
            self.reevaluate_result_storages()

    def reevaluate_result_storages(self) -> None:
        """Evaluate the best solution of every storage on every test problem.

        Fills ``self.reevaluated_results[i][objective]`` with one value per
        test problem (in ``self.test_problems`` order), where ``i`` is the
        position of the storage in ``self.list_result_storage``.

        Raises:
            RuntimeError: if ``self.test_problems`` is None, or if a storage
                has no best solution while there is something to evaluate.
        """
        if self.test_problems is None:
            raise RuntimeError(
                "self.test_problems cannot be None when calling reevaluate_result_storages()."
            )
        # enumerate() gives each storage its own key even when two storages
        # compare equal, which a list.index() lookup would conflate.
        for storage_index, res in enumerate(self.list_result_storage):
            values_by_objective: Dict[str, List[float]] = {
                obj: [] for obj in self.objectives_str
            }
            self.reevaluated_results[storage_index] = values_by_objective
            if not values_by_objective or not self.test_problems:
                # Nothing to evaluate for this storage.
                continue
            best_sol = res.get_best_solution()
            if best_sol is None:
                raise RuntimeError(
                    "res.get_best_solution() cannot be None "
                    "for any res in self.list_result_storage "
                    "when calling reevaluate_result_storages()."
                )
            for scenario in self.test_problems:
                # NOTE: change_problem() is called on the storage's own best
                # solution object, which is not switched back afterwards.
                best_sol.change_problem(scenario)
                # One evaluation per scenario serves all objectives.
                evaluation = scenario.evaluate(best_sol)
                for obj, values in values_by_objective.items():
                    values.append(evaluation[obj])
        logger.debug("reevaluated_results: %s", self.reevaluated_results)

    def plot_distribution_for_objective(self, objective_str: str) -> Figure:
        """Plot, per storage, the distribution of a re-evaluated objective.

        Args:
            objective_str: key of the objective in ``self.reevaluated_results``.

        Returns:
            The figure holding one distribution per result storage.
        """
        fig, ax = plt.subplots(1, figsize=(10, 10))
        for storage_index, storage_name in enumerate(self.result_storage_names):
            values = self.reevaluated_results[storage_index][objective_str]
            # NOTE(review): `sns.distplot` is deprecated in recent seaborn
            # releases; it is kept because switching to `histplot`/`displot`
            # would change the rendering. Confirm the pinned seaborn version.
            sns.distplot(
                values,
                rug=True,
                bins=max(1, len(values) // 10),
                label=storage_name,
                ax=ax,
            )
        ax.legend()
        ax.set_title(
            objective_str.upper()
            + " distribution over test instances, for different optimisation approaches"
        )
        return fig

    def print_test_distribution(self) -> None:
        """Placeholder: currently does nothing."""
        ...

    def get_best_by_objective_by_result_storage(
        self, objectif_str: str
    ) -> Dict[str, Tuple[Solution, fitness_class]]:
        """Return, for each storage, its best (solution, fitness) on one objective.

        "Best" is the largest component if the storage's ``maximize`` flag is
        true, the smallest otherwise; the first entry wins on ties.

        Args:
            objectif_str: name of the objective, looked up in
                ``self.objectives_str``.

        Returns:
            Mapping from storage name to the selected (solution, fitness) pair.
        """
        obj_index = self.objectives_str.index(objectif_str)

        def objective_value(solution_fit: Tuple[Solution, fitness_class]) -> float:
            # cast: indicate to mypy that we are in the multiobjective case
            return cast(TupleFitness, solution_fit[1]).vector_fitness[obj_index]

        val: Dict[str, Tuple[Solution, fitness_class]] = {}
        for storage_index, storage in enumerate(self.list_result_storage):
            select = max if storage.maximize else min
            best_sol = select(storage.list_solution_fits, key=objective_value)
            val[self.result_storage_names[storage_index]] = best_sol
        return val

    def generate_super_pareto(self) -> ParetoFront:
        """Build a Pareto front from the solutions of all storages merged."""
        merged_solution_fits = [
            solution_fit
            for storage in self.list_result_storage
            for solution_fit in storage.list_solution_fits
        ]
        merged_storage = ResultStorage(
            list_solution_fits=merged_solution_fits, best_solution=None
        )
        return result_storage_to_pareto_front(
            result_storage=merged_storage, problem=None
        )

    def _resolve_objectives(
        self, objectives_str: Optional[List[str]]
    ) -> Tuple[List[str], List[int]]:
        """Return the objective names to plot and their fitness-vector indexes.

        With None, the first two entries of ``self.objectives_str`` are used.
        """
        if objectives_str is None:
            return self.objectives_str[:2], [0, 1]
        return objectives_str, [self.objectives_str.index(obj) for obj in objectives_str]

    @staticmethod
    def _fitness_coordinates(
        storage: ResultStorage, x_index: int, y_index: int
    ) -> Tuple[List[float], List[float]]:
        """Extract two fitness-vector components of every stored solution."""
        x = [
            p[1].vector_fitness[x_index]  # type: ignore
            for p in storage.list_solution_fits
        ]
        y = [
            p[1].vector_fitness[y_index]  # type: ignore
            for p in storage.list_solution_fits
        ]
        return x, y

    def plot_all_2d_paretos_single_plot(
        self, objectives_str: Optional[List[str]] = None
    ) -> Axes:
        """Scatter two objectives of every storage on one shared axes.

        Args:
            objectives_str: the two objective names to plot (x then y);
                defaults to the first two of ``self.objectives_str``.

        Returns:
            The axes containing one colored scatter per result storage.
        """
        objectives_names, objectives_index = self._resolve_objectives(objectives_str)
        colors = get_cmap_with_nb_colors(
            color_map_str="rainbow", nb_colors=len(self.list_result_storage)
        )
        fig, ax = plt.subplots(1)
        ax.set_xlabel(objectives_names[0])
        ax.set_ylabel(objectives_names[1])
        for storage_index, storage in enumerate(self.list_result_storage):
            x, y = self._fitness_coordinates(
                storage, objectives_index[0], objectives_index[1]
            )
            ax.scatter(x=x, y=y, color=colors[storage_index])
        ax.legend(self.result_storage_names)
        return ax

    def plot_all_2d_paretos_subplots(
        self, objectives_str: Optional[List[str]] = None
    ) -> Figure:
        """Scatter two objectives of every storage, one subplot per storage.

        Args:
            objectives_str: the two objective names to plot (x then y);
                defaults to the first two of ``self.objectives_str``.

        Returns:
            The figure holding a two-column grid of subplots.
        """
        objectives_names, objectives_index = self._resolve_objectives(objectives_str)
        nb_storages = len(self.list_result_storage)
        cols = 2
        # Always request at least one row, and squeeze=False so that `axs` is
        # a 2-D array whatever the grid shape.
        rows = max(1, math.ceil(nb_storages / cols))
        fig, axs = plt.subplots(rows, cols, squeeze=False)
        axis = axs.flatten()
        colors = get_cmap_with_nb_colors(
            color_map_str="rainbow", nb_colors=nb_storages
        )
        for storage_index, (storage, ax) in enumerate(
            zip(self.list_result_storage, axis)
        ):
            x, y = self._fitness_coordinates(
                storage, objectives_index[0], objectives_index[1]
            )
            ax.scatter(x=x, y=y, color=colors[storage_index])
            ax.set_xlabel(objectives_names[0])
            ax.set_ylabel(objectives_names[1])
            ax.set_title(self.result_storage_names[storage_index])
        fig.tight_layout(pad=3.0)
        return fig

    def plot_super_pareto(self) -> None:
        """Plot the Pareto front obtained by merging all result storages."""
        super_pareto = self.generate_super_pareto()
        plot_pareto_2d(pareto_front=super_pareto, name_axis=self.objectives_str)
        plt.title("Pareto front obtained by merging solutions from all result stores")

    def plot_all_best_by_objective(self, objectif_str: str) -> None:
        """Draw a bar chart of each storage's best value on one objective.

        Args:
            objectif_str: name of the objective, looked up in
                ``self.objectives_str``.
        """
        obj_index = self.objectives_str.index(objectif_str)
        best_by_storage = self.get_best_by_objective_by_result_storage(objectif_str)
        storage_names = list(best_by_storage)
        heights = [
            best_by_storage[name][1].vector_fitness[obj_index]  # type: ignore
            for name in storage_names
        ]
        positions = np.arange(len(storage_names))
        plt.bar(positions, heights)
        plt.xticks(positions, storage_names, rotation=45)
        plt.title("Comparison on " + objectif_str)