dtorchtree.models

Description

This package contains pretrained and trainable decision-tree models for classification.

Here’s an example of how to use the pretrained models.

import torch
from dtorchtree.models import IrisDecisionTree

# Load the pretrained model.
model = IrisDecisionTree()

# Predict the class for the input sample.
X = torch.tensor([5.1, 3.5, 1.4, 0.2])
y = model.predict(X)
print(y)  # "Setosa"

Models

class DecisionTree

Decision tree for classification.

__init__(criterion: str = 'gini', max_depth: int | None = None, multiprocesses_threshold: float | None = 0.4)
Parameters:
  • criterion – The function to measure the quality of a split. Supported criteria are “gini” for the Gini impurity and “entropy” for the information gain.

  • max_depth – The maximum depth of the tree. If None, then nodes are expanded until all leaves are pure.

  • multiprocesses_threshold – The impurity threshold for using multiple processes. If a node's impurity is lower than this threshold, the node is expanded in a single process.
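The two supported criteria can be illustrated with a small sketch in plain Python. It uses lists instead of tensors so that it runs without extra dependencies, and the helper names gini, entropy and information_gain are hypothetical. It shows the standard definitions of Gini impurity (1 - Σ p_k²) and entropy (-Σ p_k log2 p_k), not dtorchtree's internal code.

```python
import math
from collections import Counter


def gini(labels):
    """Gini impurity: 1 - sum of squared class proportions."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def entropy(labels):
    """Shannon entropy in bits: -sum of p * log2(p) over class proportions."""
    n = len(labels)
    if n == 0:
        return 0.0
    return sum(-(c / n) * math.log2(c / n) for c in Counter(labels).values())


def information_gain(parent, left, right, impurity=entropy):
    """Impurity decrease obtained by splitting `parent` into `left` and `right`."""
    n = len(parent)
    weighted = (len(left) / n) * impurity(left) + (len(right) / n) * impurity(right)
    return impurity(parent) - weighted


labels = ["setosa", "setosa", "versicolor", "versicolor"]
print(gini(labels))                                      # 0.5
print(entropy(labels))                                   # 1.0
print(information_gain(labels, labels[:2], labels[2:]))  # 1.0
```

A pure node scores 0 under both criteria, and a perfectly separating split recovers the full parent impurity as gain.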

fit(X: torch.Tensor, y: torch.Tensor)
Parameters:
  • X – The training input samples, of shape (n_samples, n_features).

  • y – The target values (class labels), of shape (n_samples,).

Build a decision tree classifier from the training set (X, y).
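What building a tree from (X, y) involves can be sketched as greedy, CART-style recursion. This sketch assumes Gini impurity, axis-aligned "feature <= threshold" splits and a nested-dict node representation. The names best_split and build_tree are hypothetical, plain lists stand in for tensors, and multiprocessing is left out, so dtorchtree's actual implementation may differ. As in the max_depth description, a node is expanded until it is pure unless max_depth is reached. The sketch also stops when no split lowers the impurity.

```python
from collections import Counter


def gini(labels):
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def best_split(X, y):
    """Return (feature, threshold) with the lowest weighted Gini, or None if no split improves on the node."""
    best, best_score = None, gini(y)
    n = len(y)
    for f in range(len(X[0])):
        for t in sorted({row[f] for row in X}):
            left = [y[i] for i in range(n) if X[i][f] <= t]
            right = [y[i] for i in range(n) if X[i][f] > t]
            if not left or not right:
                continue  # skip splits that leave one side empty
            score = (len(left) * gini(left) + len(right) * gini(right)) / n
            if score < best_score:
                best, best_score = (f, t), score
    return best


def build_tree(X, y, depth=0, max_depth=None):
    """Grow the tree; stop on a pure node, at max_depth, or when no split helps."""
    majority = Counter(y).most_common(1)[0][0]
    if len(set(y)) == 1 or (max_depth is not None and depth >= max_depth):
        return {"leaf": majority}
    split = best_split(X, y)
    if split is None:
        return {"leaf": majority}
    f, t = split
    left_idx = [i for i, row in enumerate(X) if row[f] <= t]
    right_idx = [i for i, row in enumerate(X) if row[f] > t]
    return {
        "feature": f,
        "threshold": t,
        "left": build_tree([X[i] for i in left_idx], [y[i] for i in left_idx], depth + 1, max_depth),
        "right": build_tree([X[i] for i in right_idx], [y[i] for i in right_idx], depth + 1, max_depth),
    }


X = [[1.4, 0.2], [1.3, 0.2], [4.7, 1.4], [4.5, 1.5]]
y = [0, 0, 1, 1]
print(build_tree(X, y))
# {'feature': 0, 'threshold': 1.4, 'left': {'leaf': 0}, 'right': {'leaf': 1}}
```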

predict(X: torch.Tensor) → torch.Tensor
Parameters:

X – The input sample, of shape (n_features,).

Returns:

The predicted class.

Predict the class for the input sample.
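Predicting for a single sample amounts to walking from the root to a leaf. The sketch below assumes a nested-dict tree, like the one a simple builder would produce, and a hypothetical predict_one helper. The hand-built tree and its thresholds are illustrative only, with the standard Iris feature order (sepal length, sepal width, petal length, petal width). They are not the weights of IrisDecisionTree.

```python
def predict_one(tree, x):
    """Go left when x[feature] <= threshold, otherwise right, until a leaf is reached."""
    node = tree
    while "leaf" not in node:
        node = node["left"] if x[node["feature"]] <= node["threshold"] else node["right"]
    return node["leaf"]


# Hypothetical two-level tree: split on petal length (index 2), then petal width (index 3).
tree = {
    "feature": 2, "threshold": 2.45,
    "left": {"leaf": "Setosa"},
    "right": {
        "feature": 3, "threshold": 1.75,
        "left": {"leaf": "Versicolor"},
        "right": {"leaf": "Virginica"},
    },
}

print(predict_one(tree, [5.1, 3.5, 1.4, 0.2]))  # Setosa
```

Prediction cost grows with the depth of the tree, not with the number of training samples.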

save(path: str)
Parameters:

path – The path to save the model to.

Save the model to the path.

load(path: str)
Parameters:

path – The path to load the model from.

Load the model from the path.
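The on-disk format used by save and load is not described here. The following sketch shows a save/load round trip under the assumption that a tree is a nested dict serialised with Python's pickle module. The helpers save_tree and load_tree are hypothetical and are not part of dtorchtree.

```python
import os
import pickle
import tempfile


def save_tree(tree, path):
    """Serialise the tree to `path` (pickle is an assumption, not dtorchtree's format)."""
    with open(path, "wb") as f:
        pickle.dump(tree, f)


def load_tree(path):
    """Read back a tree written by save_tree. Only unpickle files you trust."""
    with open(path, "rb") as f:
        return pickle.load(f)


tree = {"feature": 2, "threshold": 2.45, "left": {"leaf": "Setosa"}, "right": {"leaf": "Other"}}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "iris_tree.pkl")
    save_tree(tree, path)
    restored = load_tree(path)

print(restored == tree)  # True
```

Because unpickling can execute arbitrary code, a format such as JSON is a safer choice when model files come from untrusted sources.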

Pretrained Models

class IrisDecisionTree(JPretrainedModel)
__init__(root: str = './models')
Parameters:

root – The root directory of the pretrained model.

Load the pretrained model.

predict(X: torch.Tensor) → torch.Tensor
Parameters:

X – The input samples, of shape (n_samples, n_features).

Returns:

The predicted class.

Predict the class for the input sample.