# Source code for tommy.controller.visualizations.correlation_matrix_creator

import matplotlib.figure
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.ticker import MaxNLocator

from tommy.controller.result_interfaces.correlation_matrix_interface import (
    CorrelationMatrixInterface)
from tommy.controller.topic_modelling_runners.abstract_topic_runner import (
    TopicRunner)
from tommy.controller.visualizations.abstract_visualization import (
    AbstractVisualization)
from tommy.controller.visualizations.possible_visualization import VisGroup
from tommy.controller.visualizations.visualization_input_datatypes import (
    VisInputData)


class CorrelationMatrixCreator(AbstractVisualization):
    """
    A class for (checking the required interfaces for and) constructing a
    correlation matrix plot for topics in the given topic runner and
    returning it as a matplotlib figure.
    """
    _required_interfaces = [CorrelationMatrixInterface, TopicRunner]
    name = 'Correlatiematrix topics'
    short_tab_name = 'Correlatie'
    vis_group = VisGroup.MODEL
    needed_input_data = [VisInputData.PROCESSED_CORPUS]

    # Number of words per topic passed to get_correlation_matrix when
    # building the matrix; named here instead of a bare literal.
    _n_words_to_process = 30

    def _create_figure(self,
                       topic_runner: TopicRunner | CorrelationMatrixInterface,
                       **kwargs
                       ) -> matplotlib.figure.Figure:
        """
        Construct a correlation matrix plot for the topics in the given
        topic runner and return it as a matplotlib figure.

        :param topic_runner: The topic model to construct the plot for. This
            should implement the CorrelationMatrixInterface (for
            get_correlation_matrix) and TopicRunner (for get_n_topics).
        :param kwargs: Accepted for interface compatibility; unused here.
        :return: Matplotlib figure showing a correlation matrix. The matrix
            entries are labeled by their 1-based topic number. The colour
            scale is fixed to [0, 1]; a value close to 0 means the topics are
            similar and close to 1 means very different.
        """
        correlation_matrix = topic_runner.get_correlation_matrix(
            n_words_to_process=self._n_words_to_process)

        fig, ax = plt.subplots()

        # Draw the matrix with a fixed [0, 1] colour range, row 0 at the
        # bottom (origin='lower'), and attach a colour bar to this axes.
        image = ax.imshow(correlation_matrix, cmap='Blues', vmin=0, vmax=1,
                          origin='lower')
        fig.colorbar(image, ax=ax)

        ax.set_title(self.name, pad=25)

        # One tick per topic, labelled starting from 1 instead of 0.
        num_topics = topic_runner.get_n_topics()
        tick_positions = np.arange(num_topics)
        tick_labels = [str(topic_number)
                       for topic_number in range(1, num_topics + 1)]
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels)
        ax.set_yticks(tick_positions)
        ax.set_yticklabels(tick_labels)

        fig.subplots_adjust(left=0.15, right=0.85, top=0.85, bottom=0.15)

        # Close exactly this figure (not whatever pyplot considers "current")
        # so pyplot releases it; the returned Figure object stays usable.
        plt.close(fig)
        return fig
""" This program has been developed by students from the bachelor Computer Science at Utrecht University within the Software Project course. © Copyright Utrecht University (Department of Information and Computing Sciences) """