dtorchtree.models
Description
This package contains pretrained and trainable decision tree models for classification.
Here’s an example of how to classify a single sample with the pretrained IrisDecisionTree model.
import dtorch
from dtorchtree.models import IrisDecisionTree
# Load the pretrained model.
model = IrisDecisionTree()
# Predict the class for a single Iris sample (four features).
X = dtorch.tensor([5.1, 3.5, 1.4, 0.2])
y = model.predict(X)
print(y) # "Setosa"
Models
- class DecisionTree
Decision tree for classification.
- __init__(criterion: str = 'gini', max_depth: int | None = None, multiprocesses_threshold: float | None = 0.4)
- Parameters:
criterion – The function to measure the quality of a split. Supported criteria are “gini” for the Gini impurity and “entropy” for the information gain.
max_depth – The maximum depth of the tree. If None, then nodes are expanded until all leaves are pure.
multiprocesses_threshold – The impurity threshold for expanding a node with multiple processes. If a node’s impurity is lower than this threshold, the node is expanded in a single process.
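To make the two criteria and the multiprocesses_threshold rule concrete, here is a minimal pure-Python sketch. It is not dtorchtree's code: the function names are hypothetical, it uses plain lists instead of tensors, and it assumes the threshold is compared against the node's Gini impurity, with nodes at or above the threshold expanded in parallel (the inverse of the single-process rule described above).

```python
import math
from collections import Counter


def gini(labels):
    """Gini impurity: 1 - sum_k p_k^2 (0.0 for a pure node)."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def entropy(labels):
    """Shannon entropy in bits: -sum_k p_k * log2(p_k) (0.0 for a pure node)."""
    n = len(labels)
    return -sum((c / n) * math.log2(c / n) for c in Counter(labels).values())


def information_gain(parent, left, right, impurity=entropy):
    """Impurity of the parent minus the size-weighted impurity of the children."""
    n = len(parent)
    weighted = len(left) / n * impurity(left) + len(right) / n * impurity(right)
    return impurity(parent) - weighted


def expand_in_parallel(labels, threshold=0.4, impurity=gini):
    """Hypothetical reading of multiprocesses_threshold: nodes whose impurity
    is below the threshold are expanded in a single process."""
    return impurity(labels) >= threshold


balanced = ["a", "a", "b", "b"]
print(gini(balanced), entropy(balanced))                        # 0.5 1.0
print(information_gain(balanced, ["a", "a"], ["b", "b"]))       # 1.0
print(expand_in_parallel(balanced))                             # True
print(expand_in_parallel(["a", "a", "a", "b"]))                 # False (Gini 0.375)
```

With k classes the Gini impurity ranges from 0 to 1 - 1/k (2/3 for the three Iris classes), so a default of 0.4 means only fairly mixed nodes would qualify for parallel expansion under this reading.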
- fit(X: torch.Tensor, y: torch.Tensor)
- Parameters:
X – The training input samples, of shape (n_samples, n_features).
y – The target values (class labels), of shape (n_samples,).
Build a decision tree classifier from the training set (X, y).
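fit grows the tree from (X, y). The sketch below shows one common way to do this: greedy CART-style splitting on axis-aligned thresholds using Gini impurity, stopping at pure nodes or at max_depth. It is an assumption about the general technique, not dtorchtree's implementation; it uses plain Python lists and a hypothetical nested-dict node layout.

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a list of class labels."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def partition(X, y, feature, threshold):
    """Split (X, y) into rows with x[feature] <= threshold and the rest."""
    left = [(row, lab) for row, lab in zip(X, y) if row[feature] <= threshold]
    right = [(row, lab) for row, lab in zip(X, y) if row[feature] > threshold]
    return left, right


def best_split(X, y):
    """Return the (feature, threshold) with the lowest weighted Gini, or None."""
    best, best_score, n = None, gini(y), len(y)
    for feature in range(len(X[0])):
        for threshold in sorted({row[feature] for row in X}):
            left, right = partition(X, y, feature, threshold)
            if not left or not right:
                continue
            score = sum(len(side) / n * gini([lab for _, lab in side])
                        for side in (left, right))
            if score < best_score:
                best, best_score = (feature, threshold), score
    return best


def build_tree(X, y, max_depth=None, depth=0):
    """Grow the tree greedily; leaves store the majority class."""
    majority = Counter(y).most_common(1)[0][0]
    if gini(y) == 0.0 or (max_depth is not None and depth >= max_depth):
        return {"leaf": majority}
    split = best_split(X, y)
    if split is None:  # no split lowers the weighted impurity
        return {"leaf": majority}
    feature, threshold = split
    children = {}
    for name, side in zip(("left", "right"), partition(X, y, feature, threshold)):
        rows = [row for row, _ in side]
        labels = [lab for _, lab in side]
        children[name] = build_tree(rows, labels, max_depth, depth + 1)
    return {"feature": feature, "threshold": threshold, **children}


print(build_tree([[1.0], [2.0], [3.0], [4.0]], [0, 0, 1, 1]))
# {'feature': 0, 'threshold': 2.0, 'left': {'leaf': 0}, 'right': {'leaf': 1}}
```

Because each child subtree is built independently of its sibling, subtrees are a natural unit of work for parallel expansion.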
- predict(X: torch.Tensor) → torch.Tensor
- Parameters:
X – A single input sample, of shape (n_features,).
- Returns:
The predicted class.
Predict the class for the input sample.
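Prediction walks a single sample from the root down to a leaf, taking the left branch when the split feature is at most the node's threshold. The sketch below assumes a hypothetical nested-dict tree layout and uses plain Python lists instead of torch.Tensor; the hand-built stump that splits on petal length (feature index 2) at 2.45 is illustrative only.

```python
def predict_one(tree, x):
    """Return the leaf class for one sample x of shape (n_features,)."""
    node = tree
    while "leaf" not in node:
        branch = "left" if x[node["feature"]] <= node["threshold"] else "right"
        node = node[branch]
    return node["leaf"]


# Hypothetical hand-built tree with a single split.
iris_stump = {
    "feature": 2,
    "threshold": 2.45,
    "left": {"leaf": "Setosa"},
    "right": {"leaf": "Not Setosa"},
}

print(predict_one(iris_stump, [5.1, 3.5, 1.4, 0.2]))  # Setosa
print(predict_one(iris_stump, [6.3, 3.3, 6.0, 2.5]))  # Not Setosa
```

Each prediction costs one comparison per level, so inference time grows with the tree depth rather than with the size of the training set.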
- save(path: str)
- Parameters:
path – The path to save the model.
Save the model to the path.
- load(path: str)
- Parameters:
path – The path to load the model.
Load the model from the path.
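save and load round-trip a fitted model through a file. The on-disk format dtorchtree uses is not documented here; the sketch below assumes a hypothetical nested-dict tree serialized as JSON, which keeps the file human-readable but would need a different encoding if leaves held tensors.

```python
import json
import os
import tempfile


def save_tree(tree, path):
    """Write a nested-dict tree to path as JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(tree, f)


def load_tree(path):
    """Read back a tree written by save_tree."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


tree = {"feature": 2, "threshold": 2.45, "left": {"leaf": 0}, "right": {"leaf": 1}}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tree.json")
    save_tree(tree, path)
    restored = load_tree(path)

print(restored == tree)  # True
```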
Pretrained Models
- class IrisDecisionTree(JPretrainedModel)
- __init__(self, root: str = './models')
- Parameters:
root – The root directory of the pretrained model.
Load the pretrained model.
- predict(self, X: torch.Tensor) → torch.Tensor
- Parameters:
X – The input samples, of shape (n_samples, n_features).
- Returns:
The predicted class.
Predict the class for the input sample.
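The usage example at the top of this page passes a single 1-D sample, while the parameter list above gives the batch shape (n_samples, n_features). A wrapper that accepts both could look like the following sketch. It is hypothetical: it uses plain Python lists and a nested-dict tree rather than dtorch tensors, and it does not describe how IrisDecisionTree actually handles input shapes.

```python
def predict_one(tree, x):
    """Walk one sample of shape (n_features,) down to a leaf."""
    node = tree
    while "leaf" not in node:
        node = node["left"] if x[node["feature"]] <= node["threshold"] else node["right"]
    return node["leaf"]


def predict(tree, X):
    """Accept a single sample (n_features,) or a batch (n_samples, n_features)."""
    if X and isinstance(X[0], (list, tuple)):
        return [predict_one(tree, row) for row in X]  # one class per row
    return predict_one(tree, X)


stump = {"feature": 2, "threshold": 2.45,
         "left": {"leaf": "Setosa"}, "right": {"leaf": "Not Setosa"}}
print(predict(stump, [5.1, 3.5, 1.4, 0.2]))                          # Setosa
print(predict(stump, [[5.1, 3.5, 1.4, 0.2], [6.3, 3.3, 6.0, 2.5]]))  # ['Setosa', 'Not Setosa']
```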