|
| 1 | +from typing import Dict, Hashable, List, Optional, Union |
| 2 | + |
1 | 3 | import inspect |
2 | 4 | import logging |
3 | 5 | import posixpath |
4 | 6 | import tempfile |
5 | 7 | import uuid |
6 | 8 | from functools import cached_property |
7 | 9 | from pathlib import Path |
8 | | -from typing import Dict, Optional, List, Union, Hashable |
9 | 10 |
|
10 | 11 | import numpy as np |
11 | 12 | import pandas |
12 | 13 | import pandas as pd |
13 | 14 | import yaml |
14 | | -from pandas.api.types import is_list_like |
15 | | -from pandas.api.types import is_numeric_dtype |
| 15 | +from mlflow import MlflowClient |
| 16 | +from pandas.api.types import is_list_like, is_numeric_dtype |
16 | 17 | from xxhash import xxh3_128_hexdigest |
17 | 18 | from zstandard import ZstdDecompressor |
18 | | -from mlflow import MlflowClient |
19 | 19 |
|
20 | 20 | from giskard.client.giskard_client import GiskardClient |
21 | | -from giskard.client.io_utils import save_df, compress |
| 21 | +from giskard.client.io_utils import compress, save_df |
22 | 22 | from giskard.client.python_utils import warning |
23 | 23 | from giskard.core.core import DatasetMeta, SupportedColumnTypes |
24 | 24 | from giskard.core.validation import configured_validate_arguments |
25 | | -from giskard.ml_worker.testing.registry.slicing_function import ( |
26 | | - SlicingFunction, |
27 | | - SlicingFunctionType, |
28 | | -) |
| 25 | +from giskard.ml_worker.testing.registry.slicing_function import SlicingFunction, SlicingFunctionType |
29 | 26 | from giskard.ml_worker.testing.registry.transformation_function import ( |
30 | 27 | TransformationFunction, |
31 | 28 | TransformationFunctionType, |
32 | 29 | ) |
33 | 30 | from giskard.settings import settings |
34 | | -from ..metadata.indexing import ColumnMetadataMixin |
| 31 | + |
35 | 32 | from ...ml_worker.utils.file_utils import get_file_name |
| 33 | +from ..metadata.indexing import ColumnMetadataMixin |
36 | 34 |
|
37 | 35 | SAMPLE_SIZE = 1000 |
38 | 36 |
|
@@ -521,7 +519,7 @@ def load(cls, local_path: str): |
521 | 519 | ) |
522 | 520 |
|
523 | 521 | @classmethod |
524 | | - def download(cls, client: GiskardClient, project_key, dataset_id, sample: bool = False): |
| 522 | + def download(cls, client: Optional[GiskardClient], project_key, dataset_id, sample: bool = False): |
525 | 523 | """ |
526 | 524 | Downloads a dataset from a Giskard project and returns a Dataset object. |
527 | 525 | If the client is None, then the function assumes that it is running in an internal worker and looks for the dataset locally. |
@@ -661,9 +659,7 @@ def to_mlflow(self, mlflow_client: MlflowClient = None, mlflow_run_id: str = Non |
661 | 659 |
|
662 | 660 | # To avoid file being open in write mode and read at the same time, |
663 | 661 | # First, we'll write it, then make sure to remove it |
664 | | - with tempfile.NamedTemporaryFile( |
665 | | - prefix="dataset-", suffix=".csv", delete=False |
666 | | - ) as f: |
| 662 | + with tempfile.NamedTemporaryFile(prefix="dataset-", suffix=".csv", delete=False) as f: |
667 | 663 | # Get file path |
668 | 664 | local_path = f.name |
669 | 665 | # Get name from file |
@@ -696,8 +692,10 @@ def to_wandb(self, **kwargs) -> None: |
696 | 692 | Additional keyword arguments |
697 | 693 | (see https://docs.wandb.ai/ref/python/init) to be added to the active WandB run. |
698 | 694 | """ |
699 | | - from giskard.integrations.wandb.wandb_utils import wandb_run |
700 | 695 | import wandb # noqa library import already checked in wandb_run |
| 696 | + |
| 697 | + from giskard.integrations.wandb.wandb_utils import wandb_run |
| 698 | + |
701 | 699 | from ...utils.analytics_collector import analytics |
702 | 700 |
|
703 | 701 | with wandb_run(**kwargs) as run: |
|
0 commit comments