
[GSK-1343] Better support for CUDA and models + docs improvements #1234

Merged
andreybavt merged 55 commits into main from task/GSK-1343
Jul 24, 2023

mattbit (Member) commented Jul 6, 2023

  • Added support for batching in WrapperModel
  • Fixed a bug in CatboostModel that prevented catboost models from being loaded correctly
  • Reorganized the models package and revised all docstrings
  • Added tests for CatboostModel, WrapperModel, BaseModel, CloudpickleSerializableModel
  • Re-enabled skipped tests that now can work! (test_catboost_changed_column_order and test_tabular_titanic_binary_classification)
  • Added all model wrappers to the docs, under “API Reference”

TODO

  • write test for batching
  • manual test with pytorch on CUDA
  • manual test with tensorflow on CUDA
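
The batching item above can be illustrated with a minimal sketch. This is not the code from the PR: `iter_batches`, `predict_in_batches`, and `predict_fn` are hypothetical names, and the only assumption is that a wrapped model exposes a callable that maps a batch of rows to a list of predictions.

```python
from typing import Callable, Iterator, List, Optional, Sequence


def iter_batches(rows: Sequence, batch_size: Optional[int]) -> Iterator[Sequence]:
    """Yield consecutive slices of `rows`, each with at most `batch_size` items."""
    if batch_size is None:
        # No batching requested: hand everything over in a single call.
        yield rows
        return
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer or None")
    for start in range(0, len(rows), batch_size):
        yield rows[start:start + batch_size]


def predict_in_batches(
    predict_fn: Callable[[Sequence], Sequence],
    rows: Sequence,
    batch_size: Optional[int] = None,
) -> List:
    """Call `predict_fn` once per batch and concatenate the results in order."""
    outputs: List = []
    for batch in iter_batches(rows, batch_size):
        outputs.extend(predict_fn(batch))
    return outputs
```

A test for batching could then check two things: that the concatenated output matches an unbatched call, and that the model never receives more than `batch_size` rows at once.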

linear (Bot) commented Jul 6, 2023

GSK-1343 Bug running HF model on GPU

When running model predictions, we do:

logits = predictions.logits.detach().numpy()

but we should first move the tensors to CPU or it will fail with

TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
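
A minimal sketch of the fix the issue describes, assuming `predictions.logits` is a PyTorch tensor; `logits_to_numpy` is a hypothetical helper name, and the PR may have solved this differently. The function is duck-typed, so it needs no import to be defined.

```python
def logits_to_numpy(logits):
    """Return `logits` as a NumPy array regardless of the tensor's device."""
    # detach(): drop the autograd graph so no gradient tracking is carried over.
    # cpu(): copy the tensor to host memory (returns the same tensor if it is
    #        already on the CPU), which avoids the TypeError quoted above.
    # numpy(): convert the host tensor to a NumPy array.
    return logits.detach().cpu().numpy()


# Usage in place of the failing line (assumes a Hugging Face style output object):
# logits = logits_to_numpy(predictions.logits)
```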

mattbit changed the title from "Better support for CUDA and docs improvements" to "[GSK-1343] Better support for CUDA and docs improvements" on Jul 6, 2023
mattbit self-assigned this on Jul 6, 2023
andreybavt (Contributor) commented:

@mattbit

  1. Sure, covering download with tests is a good idea
  2. If you want to create a custom model, you'd inherit from a CustomModel class; in that case you'd use the load from BaseModel

mattbit requested a review from rabah-khalek on July 12, 2023 14:50

rabah-khalek (Contributor) left a comment


  • the batch_size needs to be saved and reloaded properly in WrapperModel
  • good catch on the missing catboost load_model; I vividly remember having fixed it, but maybe in a branch that was never merged.
  • good job on the tests
  • good job on the doc
  • Overall I like the new refactoring of models. The previous modular structure we followed would only make sense if we had no classes, only collections of methods (as in mlflow)

==================

.. automodule:: giskard.models.base
   :members: BaseModel, WrapperModel, MLFlowSerializableModel, CloudpickleSerializableModel

Contributor

Could we rewrite this in a way that shows these 4 classes as a drop-down list when we click on "Base model classes" on the left toc menu?

Member Author

You mean the sidebar?

Contributor

Yes, the left sidebar. This one is not really urgent. I admittedly tried to do it myself, but it wasn't very intuitive without some file reorganization.

Comment thread python-client/giskard/models/base/wrapper.py
Comment thread python-client/giskard/models/base/wrapper.py
rabah-khalek (Contributor) commented:

A not-so-urgent suggestion: to keep devs from forgetting to save a child model class's metadata (such as batch_size in your case), it would be wise to extend model_validation in https://github.com/Giskard-AI/giskard/blob/c860a43a7ba1b7a5c6e11ddeafe7047d9f82409a/python-client/giskard/core/model_validation.py#L126 to compare the attribute values (as much as possible) before and after saving, with an associated test.

mattbit (Member, Author) commented Jul 12, 2023

A not-so-urgent suggestion: to keep devs from forgetting to save a child model class's metadata (such as batch_size in your case), it would be wise to extend model_validation in https://github.com/Giskard-AI/giskard/blob/c860a43a7ba1b7a5c6e11ddeafe7047d9f82409a/python-client/giskard/core/model_validation.py#L126 to compare the attribute values (as much as possible) before and after saving, with an associated test.

The problem with that would be deciding which attributes should be checked. I’m not sure that all model attributes would need to be persisted and reloaded.

rabah-khalek (Contributor) commented Jul 12, 2023

Hmm... why not? Which attributes do you have in mind?
I think, apart from the kwargs that are passed externally upon loading (and are not attributes anyway):

class BaseModel(ABC):
    @configured_validate_arguments
    def __init__(
        self,
        model_type: ModelType,
        name: Optional[str] = None,
        feature_names: Optional[Iterable] = None,
        classification_threshold: Optional[float] = 0.5,
        classification_labels: Optional[Iterable] = None,
        **kwargs,
    ):
        ...

    @classmethod
    def load(cls, local_dir, **kwargs):
        ...

all attributes should be checked, no?
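
The round-trip check discussed in this thread could look roughly like the sketch below. It is an illustration only: `ToyModel` and `changed_attributes` are hypothetical and are not part of giskard or its model_validation module. The list of attributes to compare is passed explicitly, which sidesteps the open question of whether every attribute needs to be persisted.

```python
import pickle
import tempfile
from pathlib import Path


class ToyModel:
    """Hypothetical stand-in for a model wrapper; not a giskard class."""

    def __init__(self, name, batch_size=None):
        self.name = name
        self.batch_size = batch_size

    def save(self, local_dir):
        # Persist the metadata explicitly; a forgotten key shows up in the check.
        meta = {"name": self.name, "batch_size": self.batch_size}
        with open(Path(local_dir) / "meta.pkl", "wb") as f:
            pickle.dump(meta, f)

    @classmethod
    def load(cls, local_dir):
        with open(Path(local_dir) / "meta.pkl", "rb") as f:
            meta = pickle.load(f)
        return cls(**meta)


def changed_attributes(model, attributes):
    """Save and reload `model`; return the names in `attributes` whose values differ."""
    with tempfile.TemporaryDirectory() as tmp:
        model.save(tmp)
        reloaded = type(model).load(tmp)
    return [a for a in attributes if getattr(model, a) != getattr(reloaded, a)]
```

An associated test would assert that `changed_attributes(model, [...])` returns an empty list for each model class, so a subclass that forgets to save something like batch_size fails loudly.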

mattbit requested a review from rabah-khalek on July 21, 2023 14:51

rabah-khalek (Contributor) left a comment


Thanks for the update @mattbit, I just did some minor refactoring in a70b89a. To me, it's good to go.

Comment thread python-client/giskard/models/base/wrapper.py Outdated
mattbit requested a review from andreybavt on July 24, 2023 10:27
andreybavt merged commit ff7ba18 into main on Jul 24, 2023
sonarqubecloud commented:

Kudos, SonarCloud Quality Gate passed!

  • Bugs: 0 (rating A)
  • Vulnerabilities: 0 (rating A)
  • Security Hotspots: 0 (rating A)
  • Code Smells: 9 (rating A)
  • Coverage: 87.8%
  • Duplication: 0.0%

mattbit deleted the task/GSK-1343 branch on July 24, 2023 14:41