Skip to content

fix: Correct plotting of trees with more than 1 output#2668

Closed
tincher wants to merge 1 commit into
catboost:masterfrom
tincher:fix-tree_plotting
Closed

fix: Correct plotting of trees with more than 1 output#2668
tincher wants to merge 1 commit into
catboost:masterfrom
tincher:fix-tree_plotting

Conversation

@tincher
Copy link
Copy Markdown

@tincher tincher commented May 21, 2024

Before submitting a pull request, please do the following steps:

I hereby agree to the terms of the CLA available at: https://yandex.ru/legal/cla/?lang=en.

  1. Read instructions for contributors.
  2. Make sure the code builds.
  3. If you add new functionality add tests to check it.
  4. Run existing tests to make sure you haven't broken anything.
  5. If you haven't already, sign the Contributor License Agreement.

#------

If a tree is trained with multiple outputs the plot_tree doesnt work for all trees because the indeces are too big.
I tested this with multiple outputs and multiple numbers of outputs.
I attached a minimal example which crashes:

import numpy as np
from catboost import CatBoostClassifier, Pool
from catboost.datasets import titanic

titanic_df = titanic()

y = titanic_df[0][["Survived", "Sex"]]
y.loc[:, "Sex"] = y.loc[:, "Sex"].map({"male": 1, "female": 0})
X = titanic_df[0].drop(["Survived", "Sex"], axis=1)

is_cat = X.dtypes != float
for feature, feat_is_cat in is_cat.to_dict().items():
    if feat_is_cat:
        X[feature].fillna("NAN", inplace=True)

cat_features_index = np.where(is_cat)[0]
pool = Pool(X, y, cat_features=cat_features_index, feature_names=list(X.columns))

parameters = {
    "iterations": 10,
    "depth": 3,
    "grow_policy": "Depthwise",
    "loss_function": "MultiLogloss",
    "random_seed": 0,
}
model = CatBoostClassifier(**parameters)
model.fit(pool)
model.plot_tree(0)

robot-piglet pushed a commit that referenced this pull request Aug 20, 2024
…ultidimensional approx with non-oblivious trees: #2668).

e47255eea952cef26d1cce2b8a960ad0bf3af6f8
@andrey-khropov
Copy link
Copy Markdown
Member

Thank you for the bug report. I've fixed the bug in the more general case (includes MultiRegression as well) in b45e7a5. The fix will be included in the next release.

@tincher tincher deleted the fix-tree_plotting branch August 31, 2024 05:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants