Skip to content

Commit 1d6df7d

Browse files
Merge pull request #1175 from Giskard-AI/feature/gsk-1280-slicing-functions-are-being-shared-between-all-projects
GSK-1272- Filter slicing functions clauses when column name are not available
2 parents 9955229 + c1321ba commit 1d6df7d

28 files changed

Lines changed: 313 additions & 202 deletions

backend/src/main/java/ai/giskard/domain/DatasetProcessFunction.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,21 @@
99

1010
import java.util.List;
1111
import java.util.Map;
12+
import java.util.Set;
1213

1314
@Getter
1415
@Setter
1516
@MappedSuperclass
1617
public abstract class DatasetProcessFunction extends Callable {
17-
@Column
18-
private String projectKey;
18+
19+
@ManyToMany
20+
@JoinTable(
21+
name = "dataset_process_function_projects",
22+
joinColumns = {@JoinColumn(name = "function_uuid")},
23+
inverseJoinColumns = {@JoinColumn(name = "project_id")}
24+
)
25+
private Set<Project> projects;
26+
1927
@Column(nullable = false)
2028
@ColumnDefault("false")
2129
private boolean cellLevel;

backend/src/main/java/ai/giskard/repository/ProjectRepository.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
import org.springframework.data.jpa.repository.EntityGraph;
99
import org.springframework.stereotype.Repository;
1010

11+
import java.util.Collection;
1112
import java.util.List;
1213
import java.util.Optional;
14+
import java.util.Set;
1315

1416
import static ai.giskard.web.rest.errors.EntityNotFoundException.By;
1517

@@ -47,6 +49,8 @@ default Project getOneByName(String name) {
4749

4850
Optional<Project> findOneByKey(String key);
4951

52+
Set<Project> findAllByKeyIn(Collection<String> keys);
53+
5054
default Project getOneByKey(String key) {
5155
return findOneByKey(key).orElseThrow(() -> new EntityNotFoundException(Entity.PROJECT, By.KEY, key));
5256
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package ai.giskard.repository.ml;
2+
3+
import ai.giskard.domain.DatasetProcessFunction;
4+
import ai.giskard.domain.Project;
5+
import org.springframework.data.jpa.repository.Query;
6+
import org.springframework.data.repository.NoRepositoryBean;
7+
8+
import java.util.List;
9+
10+
@NoRepositoryBean
11+
public interface DatasetProcessFunctionRepository<E extends DatasetProcessFunction> extends CallableRepository<E> {
12+
13+
@Query("SELECT e FROM #{#entityName} e LEFT JOIN e.projects p WHERE p IS NULL OR p = :project")
14+
List<E> findAllForProject(Project project);
15+
}

backend/src/main/java/ai/giskard/repository/ml/SlicingFunctionRepository.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import org.springframework.stereotype.Repository;
66

77
@Repository
8-
public interface SlicingFunctionRepository extends CallableRepository<SlicingFunction> {
8+
public interface SlicingFunctionRepository extends DatasetProcessFunctionRepository<SlicingFunction> {
99
@Override
1010
default Entity getEntityType() {
1111
return Entity.SLICING_FUNCTION;

backend/src/main/java/ai/giskard/repository/ml/TransformationFunctionRepository.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import org.springframework.stereotype.Repository;
66

77
@Repository
8-
public interface TransformationFunctionRepository extends CallableRepository<TransformationFunction> {
8+
public interface TransformationFunctionRepository extends DatasetProcessFunctionRepository<TransformationFunction> {
99

1010
@Override
1111
default Entity getEntityType() {
Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,32 @@
11
package ai.giskard.service;
22

33
import ai.giskard.domain.DatasetProcessFunction;
4+
import ai.giskard.repository.ProjectRepository;
45
import ai.giskard.repository.ml.CallableRepository;
56
import ai.giskard.web.dto.DatasetProcessFunctionDTO;
67
import ai.giskard.web.dto.mapper.GiskardMapper;
8+
import org.hibernate.Hibernate;
9+
import org.springframework.transaction.annotation.Transactional;
10+
11+
import java.util.UUID;
712

813
public abstract class DatasetProcessFunctionService<E extends DatasetProcessFunction, D extends DatasetProcessFunctionDTO> extends CallableService<E, D> {
914

10-
public DatasetProcessFunctionService(CallableRepository<E> callableRepository, GiskardMapper giskardMapper) {
15+
protected ProjectRepository projectRepository;
16+
17+
protected DatasetProcessFunctionService(CallableRepository<E> callableRepository,
18+
GiskardMapper giskardMapper,
19+
ProjectRepository projectRepository) {
1120
super(callableRepository, giskardMapper);
21+
this.projectRepository = projectRepository;
22+
}
23+
24+
@Transactional(readOnly = true)
25+
@Override
26+
public E getInitialized(UUID uuid) {
27+
E callable = super.getInitialized(uuid);
28+
Hibernate.initialize(callable.getProjects());
29+
return callable;
1230
}
1331

1432
protected E update(E existingCallable, D dto) {
@@ -17,6 +35,8 @@ protected E update(E existingCallable, D dto) {
1735
existingCallable.setColumnType(dto.getColumnType());
1836
existingCallable.setProcessType(dto.getProcessType());
1937
existingCallable.setClauses(dto.getClauses());
38+
existingCallable.getProjects().addAll(projectRepository.findAllByKeyIn(dto.getProjectKeys()));
2039
return existingCallable;
2140
}
41+
2242
}

backend/src/main/java/ai/giskard/service/SlicingFunctionService.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package ai.giskard.service;
22

33
import ai.giskard.domain.SlicingFunction;
4+
import ai.giskard.repository.ProjectRepository;
45
import ai.giskard.repository.ml.SlicingFunctionRepository;
56
import ai.giskard.utils.UUID5;
67
import ai.giskard.web.dto.ComparisonClauseDTO;
@@ -15,6 +16,7 @@
1516
import java.util.Collections;
1617
import java.util.List;
1718
import java.util.Map;
19+
import java.util.Set;
1820
import java.util.stream.Collectors;
1921

2022
import static ai.giskard.utils.GiskardStringUtils.escapePythonVariable;
@@ -26,8 +28,10 @@ public class SlicingFunctionService extends DatasetProcessFunctionService<Slicin
2628

2729
private final SlicingFunctionRepository slicingFunctionRepository;
2830

29-
public SlicingFunctionService(SlicingFunctionRepository slicingFunctionRepository, GiskardMapper giskardMapper) {
30-
super(slicingFunctionRepository, giskardMapper);
31+
public SlicingFunctionService(SlicingFunctionRepository slicingFunctionRepository,
32+
GiskardMapper giskardMapper,
33+
ProjectRepository projectRepository) {
34+
super(slicingFunctionRepository, giskardMapper, projectRepository);
3135
this.slicingFunctionRepository = slicingFunctionRepository;
3236
}
3337

@@ -41,7 +45,7 @@ public SlicingFunctionDTO save(SlicingFunctionDTO slicingFunction) {
4145

4246
protected SlicingFunction create(SlicingFunctionDTO dto) {
4347
SlicingFunction function = giskardMapper.fromDTO(dto);
44-
function.setProjectKey(dto.getProjectKey());
48+
function.setProjects(projectRepository.findAllByKeyIn(dto.getProjectKeys()));
4549
initializeCallable(function);
4650
return function;
4751
}
@@ -65,7 +69,7 @@ private Map<String, Object> toCode(ComparisonClauseDTO clause) {
6569
);
6670
}
6771

68-
public SlicingFunctionDTO generate(List<ComparisonClauseDTO> comparisonClauses) throws JsonProcessingException {
72+
public SlicingFunctionDTO generate(List<ComparisonClauseDTO> comparisonClauses, String projectKey) throws JsonProcessingException {
6973
String name = comparisonClauses.stream().map(this::clauseToString).collect(Collectors.joining(" & "));
7074

7175
SlicingFunction slicingFunction = new SlicingFunction();
@@ -83,6 +87,7 @@ public SlicingFunctionDTO generate(List<ComparisonClauseDTO> comparisonClauses)
8387
slicingFunction.setCellLevel(false);
8488
slicingFunction.setColumnType("");
8589
slicingFunction.setProcessType(DatasetProcessFunctionType.CLAUSES);
90+
slicingFunction.setProjects(Set.of(projectRepository.getOneByKey(projectKey)));
8691

8792
return giskardMapper.toDTO(slicingFunctionRepository.save(slicingFunction));
8893
}

backend/src/main/java/ai/giskard/service/TransformationFunctionService.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package ai.giskard.service;
22

33
import ai.giskard.domain.TransformationFunction;
4+
import ai.giskard.repository.ProjectRepository;
45
import ai.giskard.repository.ml.TransformationFunctionRepository;
56
import ai.giskard.web.dto.TransformationFunctionDTO;
67
import ai.giskard.web.dto.mapper.GiskardMapper;
@@ -12,8 +13,10 @@ public class TransformationFunctionService extends DatasetProcessFunctionService
1213

1314
private final TransformationFunctionRepository transformationFunctionRepository;
1415

15-
public TransformationFunctionService(TransformationFunctionRepository transformationFunctionRepository, GiskardMapper giskardMapper) {
16-
super(transformationFunctionRepository, giskardMapper);
16+
public TransformationFunctionService(TransformationFunctionRepository transformationFunctionRepository,
17+
GiskardMapper giskardMapper,
18+
ProjectRepository projectRepository) {
19+
super(transformationFunctionRepository, giskardMapper, projectRepository);
1720
this.transformationFunctionRepository = transformationFunctionRepository;
1821
}
1922

backend/src/main/java/ai/giskard/service/ml/MLWorkerCacheService.java

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package ai.giskard.service.ml;
22

3+
import ai.giskard.domain.Project;
34
import ai.giskard.ml.MLWorkerID;
45
import ai.giskard.ml.MLWorkerWSAction;
56
import ai.giskard.ml.dto.MLWorkerWSCatalogDTO;
@@ -21,7 +22,6 @@
2122
import org.springframework.transaction.annotation.Transactional;
2223

2324
import java.util.UUID;
24-
import java.util.stream.Stream;
2525

2626
import static ai.giskard.ml.dto.MLWorkerWSUtils.convertMLWorkerWSObject;
2727

@@ -39,44 +39,32 @@ public class MLWorkerCacheService {
3939
private final TransformationFunctionRepository transformationFunctionRepository;
4040
private final ProjectRepository projectRepository;
4141
private final GiskardMapper giskardMapper;
42-
private CatalogDTO catalogWithoutPickles = new CatalogDTO();
4342

4443
@Transactional
4544
public CatalogDTO getCatalog(long projectId) {
4645
// TODO: Remove from transaction, however it mostly relly on cache so impact is reduced
47-
CatalogDTO catalog = findGiskardTest(projectRepository.getMandatoryById(projectId).isUsingInternalWorker());
46+
47+
Project project = projectRepository.getMandatoryById(projectId);
48+
findGiskardTest(project.isUsingInternalWorker());
4849

4950
return CatalogDTO.builder()
50-
.tests(Stream.concat(
51-
testFunctionRepository.findAll().stream().map(giskardMapper::toDTO),
52-
catalog.getTests().stream()
53-
)
54-
.toList())
55-
.slices(Stream.concat(
56-
slicingFunctionRepository.findAll().stream().map(giskardMapper::toDTO),
57-
catalog.getSlices().stream()
58-
)
59-
.toList())
60-
.transformations(Stream.concat(
61-
transformationFunctionRepository.findAll().stream().map(giskardMapper::toDTO),
62-
catalog.getTransformations().stream()
63-
)
64-
.toList())
51+
.tests(testFunctionRepository.findAll().stream().map(giskardMapper::toDTO).toList())
52+
.slices(slicingFunctionRepository.findAllForProject(project).stream()
53+
.map(giskardMapper::toDTO).toList())
54+
.transformations(transformationFunctionRepository.findAllForProject(project).stream()
55+
.map(giskardMapper::toDTO).toList())
6556
.build();
6657
}
6758

68-
public CatalogDTO findGiskardTest(boolean isInternal) {
59+
public void findGiskardTest(boolean isInternal) {
6960
if (isInternal) {
70-
// Only cache external ML worker
71-
return getTestFunctions(true);
61+
return;
7262
}
7363

74-
catalogWithoutPickles = getTestFunctions(false);
64+
CatalogDTO catalogWithoutPickles = getTestFunctions(false);
7565
testFunctionService.saveAll(catalogWithoutPickles.getTests());
7666
slicingFunctionService.saveAll(catalogWithoutPickles.getSlices());
7767
transformationFunctionService.saveAll(catalogWithoutPickles.getTransformations());
78-
79-
return catalogWithoutPickles;
8068
}
8169

8270
private CatalogDTO getTestFunctions(boolean isInternal) {

backend/src/main/java/ai/giskard/utils/TransactionUtils.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,5 @@ public static void initializeCallable(Callable callable) {
2828
Hibernate.initialize(callable.getArgs());
2929
}
3030

31+
3132
}

0 commit comments

Comments
 (0)