Source code for trident_chemwidgets.widgets.scatter

import pandas as pd
import numpy as np
import re
from ipywidgets import DOMWidget
from traitlets import Unicode, Dict, Float, List, Integer, Bool, Any, Union
from .._frontend import module_name, module_version


class Scatter(DOMWidget):
    """Plot an interactive scatter plot based on the given data and the
    selected variables to generate the axis.

    The scatter plot will be displayed to the left of the cell output, with
    a molecule gallery displayed to the right. The molecule gallery can show
    the structures present in the currently-selected subset of the data.

    Args:
        data (pd.DataFrame): Data used to generate the scatter plot.
        smiles (str): Name of the column that contains the SMILES string of
            each molecule.
        x (str): Name of the column used to generate the x-axis of the
            scatter plot.
        y (str): Name of the column used to generate the y-axis of the
            scatter plot.
        hue (str): Name of the column used to color the points of the
            scatter plot.
        x_label (str): Label for the x-axis of the scatter plot, defaults to
            the value of `x` if not provided.
        y_label (str): Label for the y-axis of the scatter plot, defaults to
            the value of `y` if not provided.
        hue_label (str): Label for the point colors of the scatter plot,
            defaults to the value of `hue` if not provided.
        x_date_format (str): Date format string to display datetime values
            on the x axis.
        y_date_format (str): Date format string to display datetime values
            on the y axis.

    Notes:
        Valid date format strings for the `x_date_format` and
        `y_date_format` arguments can be found here:
        https://github.com/d3/d3-time-format#locale_format.
        For example, a common date format string might be '%Y-%m-%d' to
        display the 4-digit year, 2-digit month, and 2-digit day
        (i.e. 2021-12-25).

    Examples:
        >>> import trident_chemwidgets as tcw
        >>> import pandas as pd
        >>> dataset = pd.read_csv(PATH)
        >>> scatter = tcw.Scatter(data=dataset, smiles='smiles', x='mwt', y='logp')
        >>> scatter
    """

    _model_name = Unicode('ScatterModel').tag(sync=True)
    _model_module = Unicode(module_name).tag(sync=True)
    _model_module_version = Unicode(module_version).tag(sync=True)
    _view_name = Unicode('ScatterView').tag(sync=True)
    _view_module = Unicode(module_name).tag(sync=True)
    _view_module_version = Unicode(module_version).tag(sync=True)

    # X-Axis params
    x_label = Unicode('x').tag(sync=True)
    x_is_date = Bool(False).tag(sync=True)
    x_format_date_string = Unicode('').tag(sync=True)

    # Y-Axis params
    y_label = Unicode('y').tag(sync=True)
    y_is_date = Bool(False).tag(sync=True)
    y_format_date_string = Unicode('').tag(sync=True)

    # Hue params
    hue_label = Unicode().tag(sync=True)
    hue_type = Unicode().tag(sync=True)
    hue_min = Float().tag(sync=True)
    hue_max = Float().tag(sync=True)
    # hue_scale = Unicode('linear').tag(sync=True)

    # Plot payload: a dict with a single 'points' list, one record per row.
    data = Dict(per_key_traits={
        'points': List(trait=Dict(per_key_traits={
            'index': Integer(),
            'smiles': Unicode(),
            'x': Any(),
            'y': Any(),
        }))
    }).tag(sync=True)

    # Positional row indices of the current selection (synced trait).
    savedSelected = List(trait=Integer()).tag(sync=True)

    def __init__(
        self,
        data: pd.DataFrame,
        smiles: str,
        x: str,
        y: str,
        hue: str = None,
        x_label: str = None,
        y_label: str = None,
        hue_label: str = None,
        x_date_format: str = None,
        y_date_format: str = None,
        **kwargs
    ):
        super().__init__(**kwargs)

        self._smiles_col = smiles
        self._x_col = x
        self._y_col = y

        # An empty hue name is treated the same as "no hue column".
        self._hue = hue if hue else None
        if self._hue:
            self.hue_label = hue_label if hue_label else hue

        # Axis labels fall back to the column names.
        self.x_label = x_label if x_label else x
        self.y_label = y_label if y_label else y

        self._format_x_date = x_date_format if x_date_format else ''
        self._format_y_date = y_date_format if y_date_format else ''

        self._data = data
        self.data = self.prep_data_for_plot()

    @staticmethod
    def _coerce_axis(column):
        """Detect whether an axis column holds dates and stringify it if so.

        Args:
            column (pd.Series): Values of one plot axis.

        Returns:
            tuple: ``(values, is_date)``. Numeric columns and columns that
            cannot be parsed as dates are returned unchanged with
            ``is_date`` set to False. Date-like columns are returned as the
            string form of each parsed timestamp with ``is_date`` True.
        """
        # Use the dtype API rather than matching dtype names: names such as
        # 'uint8', 'Int64' or 'bool' would otherwise reach to_datetime and
        # either be read as epoch timestamps or raise.
        if pd.api.types.is_numeric_dtype(column.dtype):
            return column, False

        try:
            # NOTE: values are sent as strings; the date conversion itself
            # is left to the frontend once the axis is flagged as a date.
            as_text = pd.to_datetime(column).apply(str)
        except (ValueError, TypeError):
            # Not parseable as dates: plot the raw values.
            return column, False
        return as_text, True

    def prep_data_for_plot(self):
        """Transforms and selects the data correctly for use by the plot.

        Also updates the hue traits (`hue_type`, `hue_min`, `hue_max`) and
        the per-axis date traits as a side effect.

        Returns:
            dict: Data in dict format to be used in plot, shaped as
            ``{'points': [record, ...]}`` where each record holds the
            selected columns plus its positional ``'index'``.
        """
        columns = {
            'smiles': self._data[self._smiles_col].values.copy(),
            'x': self._data[self._x_col].values.copy(),
            'y': self._data[self._y_col].values.copy(),
        }
        if self._hue:
            columns['hue'] = self._data[self._hue].values.copy()
        data = pd.DataFrame(columns)

        if self._hue:
            # Dtype name with the bit width stripped, e.g. 'float64' ->
            # 'float'. This string is synced, so its form is kept as is.
            self.hue_type = re.sub('[0-9]', '', str(data['hue'].dtype))

            # Only use the hue_min and hue_max for the domain in float values
            if self.hue_type == 'float':
                self.hue_max = data['hue'].max()
                self.hue_min = data['hue'].min()

        x_values, x_is_date = self._coerce_axis(data['x'])
        data['x'] = x_values
        self.x_is_date = x_is_date
        if x_is_date:
            self.x_format_date_string = self._format_x_date

        y_values, y_is_date = self._coerce_axis(data['y'])
        data['y'] = y_values
        self.y_is_date = y_is_date
        if y_is_date:
            self.y_format_date_string = self._format_y_date

        records = data.to_dict(orient='records')
        # The index is positional so it can be used with `iloc` later.
        for position, record in enumerate(records):
            record['index'] = position

        return {'points': records}

    @property
    def selection(self):
        """Current selection of molecules made by the user.

        Returns:
            pd.DataFrame: Rows of the original data at the positions listed
            in `savedSelected`.
        """
        return self._data.iloc[self.savedSelected]