[GSK-1343] Better support for CUDA and models + docs improvements#1234
[GSK-1343] Better support for CUDA and models + docs improvements#1234andreybavt merged 55 commits intomainfrom
Conversation
GSK-1343 Bug running HF model on GPU
When running model predictions, we do: but we should first move the tensors to CPU or it will fail with |
|
There was a problem hiding this comment.
- the
batch_sizeto be saved and reloaded properly inWrapperModel - good catch the catboost missing
load_model, I remember vividly to have fixed it, but maybe in a branch that has never been merged. - good job on the tests
- good job on the doc
- Overall I like the new refactoring of models, the previous modular structure we followed only makes sense if we didn't have classes, only collection of methods (such as mlflow)
| ================== | ||
|
|
||
| .. automodule:: giskard.models.base | ||
| :members: BaseModel, WrapperModel, MLFlowSerializableModel, CloudpickleSerializableModel |
There was a problem hiding this comment.
Could we rewrite this in a way that shows these 4 classes as a drop-down list when we click on "Base model classes" on the left toc menu?
There was a problem hiding this comment.
yes, the left sidebar - not really urgent this one. I admittedly tried to do it myself, wasn't super intuitive without some files re-organization.
|
A not-so-urgent suggestion: to avoid devs from forgetting the saving of daughter model class' meta such as |
The problem with that would be deciding which attributes should be checked. I’m not sure that all model attributes would need to be persisted and reloaded. |
|
hmm... Why not? which attributes you have in mind? class BaseModel(ABC):
@configured_validate_arguments
def __init__(
self,
model_type: ModelType,
name: Optional[str] = None,
feature_names: Optional[Iterable] = None,
classification_threshold: Optional[float] = 0.5,
classification_labels: Optional[Iterable] = None,
**kwargs,
) :
...
@classmethod
def load(cls, local_dir, **kwargs):
...all attributes should be checked no? |
|
Kudos, SonarCloud Quality Gate passed! |








WrapperModelCatboostModelthat preventedcatboostmodels from being loaded correctlymodelspackage and revised all docstringsCatboostModel,WrapperModel,BaseModel,CloudpickleSerializableModeltest_catboost_changed_column_orderandtest_tabular_titanic_binary_classification)TODO