-
Notifications
You must be signed in to change notification settings - Fork 5.6k
Add box plot in compare-runs page #6308
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1ce62cc
50d14ee
39b2395
50c18a5
86e6692
f8a9360
4bb2476
fa799a7
7b2c99b
20d23e2
a20d670
e17ad71
a173e6a
b8d3e94
0e7bf47
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,176 @@ | ||
| import React, { useState } from 'react'; | ||
| import { FormattedMessage } from 'react-intl'; | ||
| import PropTypes from 'prop-types'; | ||
| import { Row, Col, Select } from 'antd'; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we use
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
import { Select } from '@databricks/design-system';
import { Select as AntSelect } from 'antd';
const { Option } = Select
const { OptGroup } = AntSelect
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we use antd's
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Discussed with @hubertzub-db on this yesterday. We recently migrated the contour and scatter plot to |
||
| import { Typography } from '@databricks/design-system'; | ||
| import { RunInfo } from '../sdk/MlflowMessages'; | ||
| import { LazyPlot } from './LazyPlot'; | ||
|
|
||
| const { Option, OptGroup } = Select; | ||
|
|
||
| export const CompareRunBox = ({ runUuids, runInfos, metricLists, paramLists }) => { | ||
| const [xAxis, setXAxis] = useState({ key: undefined, isParam: undefined }); | ||
| const [yAxis, setYAxis] = useState({ key: undefined, isParam: undefined }); | ||
|
|
||
| const paramKeys = Array.from(new Set(paramLists.flat().map(({ key }) => key))).sort(); | ||
| const metricKeys = Array.from(new Set(metricLists.flat().map(({ key }) => key))).sort(); | ||
|
|
||
| const paramOptionPrefix = 'param-'; | ||
| const metricOptionPrefix = 'metric-'; | ||
|
|
||
| const handleXAxisChange = (_, { value, key }) => { | ||
| const isParam = value.startsWith(paramOptionPrefix); | ||
| setXAxis({ key, isParam }); | ||
| }; | ||
|
|
||
| const handleYAxisChange = (_, { value, key }) => { | ||
| const isParam = value.startsWith(paramOptionPrefix); | ||
| setYAxis({ key, isParam }); | ||
| }; | ||
|
|
||
| const renderSelector = (onChange, selectedValue) => ( | ||
| <Select | ||
| css={{ width: '100%', marginBottom: '16px' }} | ||
| placeholder='Select' | ||
| onChange={onChange} | ||
| value={selectedValue} | ||
| > | ||
| <OptGroup label='Parameters' key='parameters'> | ||
| {paramKeys.map((key) => ( | ||
| <Option key={key} value={paramOptionPrefix + key}> | ||
| <div data-test-id='axis-option'>{key}</div> | ||
| </Option> | ||
| ))} | ||
| </OptGroup> | ||
| <OptGroup label='Metrics'> | ||
| {metricKeys.map((key) => ( | ||
| <Option key={key} value={metricOptionPrefix + key}> | ||
| <div data-test-id='axis-option'>{key}</div> | ||
| </Option> | ||
| ))} | ||
| </OptGroup> | ||
| </Select> | ||
| ); | ||
|
|
||
| const getBoxPlotData = () => { | ||
| const data = {}; | ||
| runInfos.forEach((_, index) => { | ||
| const params = paramLists[index]; | ||
| const metrics = metricLists[index]; | ||
| const x = (xAxis.isParam ? params : metrics).find(({ key }) => key === xAxis.key); | ||
| const y = (yAxis.isParam ? params : metrics).find(({ key }) => key === yAxis.key); | ||
| if (x === undefined || y === undefined) { | ||
| return; | ||
| } | ||
|
|
||
| if (x.value in data) { | ||
| data[x.value].push(y.value); | ||
| } else { | ||
| data[x.value] = [y.value]; | ||
| } | ||
| }); | ||
|
|
||
| return Object.entries(data).map(([key, values]) => ({ | ||
| y: values, | ||
| type: 'box', | ||
| name: key, | ||
| jitter: 0.3, | ||
| pointpos: -1.5, | ||
| boxpoints: 'all', | ||
| })); | ||
| }; | ||
|
|
||
| const renderPlot = () => { | ||
| if (!(xAxis.key && yAxis.key)) { | ||
| return ( | ||
| <div | ||
| css={{ | ||
| display: 'flex', | ||
| width: '100%', | ||
| height: '100%', | ||
| justifyContent: 'center', | ||
| alignItems: 'center', | ||
| }} | ||
| > | ||
| <Typography.Text size='xl'> | ||
| <FormattedMessage | ||
| defaultMessage='Select parameters/metrics to plot.' | ||
| description='Text to show when x or y axis is not selected on box plot' | ||
| /> | ||
| </Typography.Text> | ||
| </div> | ||
| ); | ||
| } | ||
|
|
||
| return ( | ||
| <LazyPlot | ||
| css={{ | ||
| width: '100%', | ||
| height: '100%', | ||
| minHeight: '35vw', | ||
| }} | ||
| data={getBoxPlotData()} | ||
| layout={{ | ||
| margin: { | ||
| t: 30, | ||
| }, | ||
| hovermode: 'closest', | ||
| xaxis: { | ||
| title: xAxis.key, | ||
| }, | ||
| yaxis: { | ||
| title: yAxis.key, | ||
| }, | ||
| }} | ||
| config={{ | ||
| responsive: true, | ||
| displaylogo: false, | ||
| scrollZoom: true, | ||
| modeBarButtonsToRemove: [ | ||
| 'sendDataToCloud', | ||
| 'select2d', | ||
| 'lasso2d', | ||
| 'resetScale2d', | ||
| 'hoverClosestCartesian', | ||
| 'hoverCompareCartesian', | ||
| ], | ||
| }} | ||
| useResizeHandler | ||
| /> | ||
| ); | ||
| }; | ||
|
|
||
| return ( | ||
| <Row> | ||
| <Col span={6}> | ||
| <div> | ||
| <label htmlFor='x-axis-selector'> | ||
| <FormattedMessage | ||
| defaultMessage='X-axis:' | ||
| description='Label text for X-axis in box plot comparison in MLflow' | ||
| /> | ||
| </label> | ||
| </div> | ||
| {renderSelector(handleXAxisChange, xAxis.value)} | ||
|
|
||
| <div> | ||
| <label htmlFor='y-axis-selector'> | ||
| <FormattedMessage | ||
| defaultMessage='Y-axis:' | ||
| description='Label text for Y-axis in box plot comparison in MLflow' | ||
| /> | ||
| </label> | ||
| </div> | ||
| {renderSelector(handleYAxisChange, yAxis.value)} | ||
| </Col> | ||
| <Col span={18}>{renderPlot()}</Col> | ||
| </Row> | ||
| ); | ||
| }; | ||
|
|
||
| CompareRunBox.propTypes = { | ||
| runUuids: PropTypes.arrayOf(PropTypes.string).isRequired, | ||
| runInfos: PropTypes.arrayOf(PropTypes.instanceOf(RunInfo)).isRequired, | ||
| metricLists: PropTypes.arrayOf(PropTypes.arrayOf(PropTypes.object)).isRequired, | ||
| paramLists: PropTypes.arrayOf(PropTypes.arrayOf(PropTypes.object)).isRequired, | ||
| }; | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
| import React from 'react'; | ||
| import { Select } from 'antd'; | ||
| import { CompareRunBox } from './CompareRunBox'; | ||
| import { RunInfo } from '../sdk/MlflowMessages'; | ||
| import { mountWithIntl } from '../../common/utils/TestUtils'; | ||
| import { LazyPlot } from './LazyPlot'; | ||
|
|
||
| describe('CompareRunBox', () => { | ||
| let wrapper; | ||
|
|
||
| const runUuids = ['1', '2', '3']; | ||
| const commonProps = { | ||
| runUuids, | ||
| runInfos: runUuids.map((run_uuid) => | ||
| RunInfo.fromJs({ | ||
| run_uuid, | ||
| experiment_id: '0', | ||
| }), | ||
| ), | ||
| runDisplayNames: runUuids, | ||
| }; | ||
|
|
||
| test('should render with minimal props without exploding', () => { | ||
| const props = { | ||
| ...commonProps, | ||
| paramLists: [ | ||
| [{ key: 'param', value: 1 }], | ||
| [{ key: 'param', value: 2 }], | ||
| [{ key: 'param', value: 3 }], | ||
| ], | ||
| metricLists: [ | ||
| [{ key: 'metric', value: 4 }], | ||
| [{ key: 'metric', value: 5 }], | ||
| [{ key: 'metric', value: 6 }], | ||
| ], | ||
| }; | ||
|
|
||
| wrapper = mountWithIntl(<CompareRunBox {...props} />); | ||
| expect(wrapper.find(LazyPlot).isEmpty()).toBe(true); | ||
| expect(wrapper.text()).toContain('Select parameters/metrics to plot.'); | ||
|
|
||
| const selectors = wrapper.find(Select); | ||
| expect(selectors.length).toBe(2); | ||
| // Set x-axis to 'param' | ||
| const xAxisSelector = selectors.at(0); | ||
| xAxisSelector.find('input[type="search"]').simulate('mouseDown'); | ||
| // `wrapper.find` can't find the selector options because they appear in the top level of the | ||
| // document. | ||
| document.querySelectorAll('[data-test-id="axis-option"]')[0].click(); | ||
| expect(xAxisSelector.text()).toContain('param'); | ||
| // Set y-axis to 'metric' | ||
| const yAxisSelector = selectors.at(1); | ||
| yAxisSelector.find('input[type="search"]').simulate('mouseDown'); | ||
| document.querySelectorAll('[data-test-id="axis-option"]')[3].click(); | ||
| expect(yAxisSelector.text()).toContain('metric'); | ||
| wrapper.update(); | ||
| expect(wrapper.find(LazyPlot).exists()).toBe(true); | ||
| expect(wrapper.text()).not.toContain('Select parameters/metrics to plot.'); | ||
| }); | ||
| }); |
Uh oh!
There was an error while loading. Please reload this page.