Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions giskard/models/automodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

import pandas as pd

from ..core.core import ModelType, SupportedModelTypes
from .base.serialization import CloudpickleSerializableModel
from .function import PredictionFunctionModel
from ..core.core import ModelType, SupportedModelTypes

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -128,8 +128,8 @@ def __new__(
giskard_cls = CloudpickleSerializableModel
# if save_model and load_model are overriden, replace them, if not, these equalities will be identities.
possibly_overriden_cls = cls
possibly_overriden_cls.save_model = giskard_cls.save_model
possibly_overriden_cls.load_model = giskard_cls.load_model
# possibly_overriden_cls.save_model = giskard_cls.save_model
# possibly_overriden_cls.load_model = giskard_cls.load_model
Comment thread
rabah-khalek marked this conversation as resolved.
Outdated
possibly_overriden_cls.should_save_model_class = True
elif giskard_cls:
input_type = "'prediction_function'" if giskard_cls == PredictionFunctionModel else "'model'"
Expand Down
34 changes: 34 additions & 0 deletions tests/models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pathlib import Path

import numpy as np
import pandas as pd

from giskard import Model
from giskard.core.core import SupportedModelTypes
Expand Down Expand Up @@ -57,3 +58,36 @@ def prediction_fn(df):

assert np.all(np.equal(predictions.raw, second_predictions.raw))
assert nb_of_prediction_calls[0] == 1


def test_model_save_and_load_not_overriden():
def model_fn(df):
return [True] * len(df)

call_count = dict({"save": 0, "load": 0})

class MyCustomModel(Model):
def save_model(self, path):
call_count["save"] = call_count["save"] + 1
Path(path).joinpath("custom_data").touch()

@classmethod
def load_model(cls, path, **kwargs):
call_count["load"] = call_count["load"] + 1

def model(x):
return [True] * len(x)

return model

def model_predict(self, df: pd.DataFrame):
return self.model(df)

with tempfile.TemporaryDirectory() as tmpdirname:
gsk_model = MyCustomModel(model_fn, model_type="regression")

gsk_model.save(tmpdirname)
assert call_count["save"] == 1

MyCustomModel.load(tmpdirname)
assert call_count["load"] == 1