Skip to content

Commit e17ab7f

Browse files
committed
Revert "Use workers type as-is" to test internal worker
This reverts commit 0cbd8ae.
1 parent 7d6bcff commit e17ab7f

7 files changed

Lines changed: 35 additions & 1 deletion

File tree

backend/src/main/java/ai/giskard/service/ModelService.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ public class ModelService {
4444

4545
public MLWorkerWSRunModelForDataFrameDTO predict(ProjectModel model, Dataset dataset, Map<String, String> features) {
4646
MLWorkerID workerID = model.getProject().isUsingInternalWorker() ? MLWorkerID.INTERNAL : MLWorkerID.EXTERNAL;
47+
if (GeneralSettingsService.IS_RUNNING_IN_DEMO_HF_SPACES) {
48+
workerID = MLWorkerID.INTERNAL;
49+
}
4750
if (mlWorkerWSService.isWorkerConnected(workerID)) {
4851
return getRunModelForDataFrameResponse(model, dataset, features);
4952
}
@@ -81,6 +84,9 @@ private MLWorkerWSRunModelForDataFrameDTO getRunModelForDataFrameResponse(Projec
8184
}
8285

8386
MLWorkerID workerID = model.getProject().isUsingInternalWorker() ? MLWorkerID.INTERNAL : MLWorkerID.EXTERNAL;
87+
if (GeneralSettingsService.IS_RUNNING_IN_DEMO_HF_SPACES) {
88+
workerID = MLWorkerID.INTERNAL;
89+
}
8490
// Perform the runModelForDataFrame action and parse the reply
8591
MLWorkerWSBaseDTO result = mlWorkerWSCommService.performAction(
8692
workerID,
@@ -102,6 +108,9 @@ public boolean shouldDrop(String columnDtype, String value) {
102108

103109
public MLWorkerWSExplainDTO explain(ProjectModel model, Dataset dataset, Map<String, String> features) {
104110
MLWorkerID workerID = model.getProject().isUsingInternalWorker() ? MLWorkerID.INTERNAL : MLWorkerID.EXTERNAL;
111+
if (GeneralSettingsService.IS_RUNNING_IN_DEMO_HF_SPACES) {
112+
workerID = MLWorkerID.INTERNAL;
113+
}
105114
if (mlWorkerWSService.isWorkerConnected(workerID)) {
106115
MLWorkerWSExplainParamDTO param = MLWorkerWSExplainParamDTO.builder()
107116
.model(MLWorkerWSArtifactRefDTO.fromModel(model))
@@ -126,6 +135,9 @@ public MLWorkerWSExplainDTO explain(ProjectModel model, Dataset dataset, Map<Str
126135

127136
public MLWorkerWSExplainTextDTO explainText(ProjectModel model, Dataset dataset, String featureName, Map<String, String> features) {
128137
MLWorkerID workerID = model.getProject().isUsingInternalWorker() ? MLWorkerID.INTERNAL : MLWorkerID.EXTERNAL;
138+
if (GeneralSettingsService.IS_RUNNING_IN_DEMO_HF_SPACES) {
139+
workerID = MLWorkerID.INTERNAL;
140+
}
129141
if (mlWorkerWSService.isWorkerConnected(workerID)) {
130142
MLWorkerWSExplainTextParamDTO param = MLWorkerWSExplainTextParamDTO.builder()
131143
.model(MLWorkerWSArtifactRefDTO.fromModel(model))
@@ -174,6 +186,9 @@ public Inspection createInspection(String name, UUID modelId, UUID datasetId, bo
174186

175187
protected void predictSerializedDataset(ProjectModel model, Dataset dataset, Long inspectionId, boolean sample) {
176188
MLWorkerID workerID = model.getProject().isUsingInternalWorker() ? MLWorkerID.INTERNAL : MLWorkerID.EXTERNAL;
189+
if (GeneralSettingsService.IS_RUNNING_IN_DEMO_HF_SPACES) {
190+
workerID = MLWorkerID.INTERNAL;
191+
}
177192
if (mlWorkerWSService.isWorkerConnected(workerID)) {
178193
// Initialize params
179194
MLWorkerWSRunModelParamDTO param = MLWorkerWSRunModelParamDTO.builder()

backend/src/main/java/ai/giskard/service/TestSuiteExecutionService.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ public void executeScheduledTestSuite(TestSuiteExecution execution,
5757
}));
5858

5959
MLWorkerID workerID = suite.getProject().isUsingInternalWorker() ? MLWorkerID.INTERNAL : MLWorkerID.EXTERNAL;
60+
if (GeneralSettingsService.IS_RUNNING_IN_DEMO_HF_SPACES) {
61+
workerID = MLWorkerID.INTERNAL;
62+
}
6063
if (mlWorkerWSService.isWorkerConnected(workerID)) {
6164
Map<String, FunctionInput> suiteInputsAndShared = Stream.concat(
6265
execution.getInputs().stream(),

backend/src/main/java/ai/giskard/service/TestSuiteService.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,9 @@ public Long generateTestSuite(String projectKey, GenerateTestSuiteDTO dto) {
202202

203203
Project project = projectRepository.getOneByKey(projectKey);
204204
MLWorkerID workerID = project.isUsingInternalWorker() ? MLWorkerID.INTERNAL : MLWorkerID.EXTERNAL;
205+
if (GeneralSettingsService.IS_RUNNING_IN_DEMO_HF_SPACES) {
206+
workerID = MLWorkerID.INTERNAL;
207+
}
205208
if (mlWorkerWSService.isWorkerConnected(workerID)) {
206209
MLWorkerWSGenerateTestSuiteParamDTO param = MLWorkerWSGenerateTestSuiteParamDTO.builder()
207210
.projectKey(projectKey)

backend/src/main/java/ai/giskard/web/rest/controllers/DatasetsController.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,9 @@ public DatasetProcessingResultDTO datasetProcessing(@PathVariable("projectId") @
139139
.collect(Collectors.toMap(Callable::getUuid, Function.identity(), (l, r) -> l));
140140

141141
MLWorkerID workerID = project.isUsingInternalWorker() ? MLWorkerID.INTERNAL : MLWorkerID.EXTERNAL;
142+
if (GeneralSettingsService.IS_RUNNING_IN_DEMO_HF_SPACES) {
143+
workerID = MLWorkerID.INTERNAL;
144+
}
142145
if (!mlWorkerWSService.isWorkerConnected(workerID)) {
143146
throw new MLWorkerNotConnectedException(workerID, log);
144147
}

backend/src/main/java/ai/giskard/web/rest/controllers/MLWorkerController.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import ai.giskard.ml.dto.MLWorkerWSBaseDTO;
66
import ai.giskard.ml.dto.MLWorkerWSGetInfoDTO;
77
import ai.giskard.ml.dto.MLWorkerWSGetInfoParamDTO;
8+
import ai.giskard.service.GeneralSettingsService;
89
import ai.giskard.service.ml.MLWorkerService;
910
import ai.giskard.service.ml.MLWorkerWSCommService;
1011
import ai.giskard.service.ml.MLWorkerWSService;
@@ -71,6 +72,9 @@ public List<MLWorkerInfoDTO> getMLWorkerInfo()
7172
@PostMapping("/stop")
7273
public void stopWorker(boolean internal) {
7374
MLWorkerID workerID = internal ? MLWorkerID.INTERNAL : MLWorkerID.EXTERNAL;
75+
if (GeneralSettingsService.IS_RUNNING_IN_DEMO_HF_SPACES) {
76+
workerID = MLWorkerID.INTERNAL;
77+
}
7478
if (mlWorkerWSService.isWorkerConnected(workerID)) {
7579
mlWorkerWSCommService.performAction(workerID, MLWorkerWSAction.STOP_WORKER, null);
7680
}

backend/src/main/java/ai/giskard/web/rest/controllers/testing/TestController.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import ai.giskard.ml.dto.MLWorkerWSRunAdHocTestParamDTO;
1515
import ai.giskard.repository.ProjectRepository;
1616
import ai.giskard.repository.ml.TestFunctionRepository;
17+
import ai.giskard.service.GeneralSettingsService;
1718
import ai.giskard.service.TestArgumentService;
1819
import ai.giskard.service.ml.MLWorkerWSCommService;
1920
import ai.giskard.service.ml.MLWorkerWSService;
@@ -50,6 +51,9 @@ public TestTemplateExecutionResultDTO runAdHocTest(@RequestBody RunAdhocTestRequ
5051
Project project = projectRepository.getMandatoryById(request.getProjectId());
5152

5253
MLWorkerID workerID = project.isUsingInternalWorker() ? MLWorkerID.INTERNAL : MLWorkerID.EXTERNAL;
54+
if (GeneralSettingsService.IS_RUNNING_IN_DEMO_HF_SPACES) {
55+
workerID = MLWorkerID.INTERNAL;
56+
}
5357
if (mlWorkerWSService.isWorkerConnected(workerID)) {
5458
MLWorkerWSRunAdHocTestParamDTO.MLWorkerWSRunAdHocTestParamDTOBuilder builder =
5559
MLWorkerWSRunAdHocTestParamDTO.builder()

backend/src/main/java/ai/giskard/web/socket/WorkerStatusSocketService.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import ai.giskard.event.UpdateWorkerStatusEvent;
44
import ai.giskard.ml.MLWorkerID;
5+
import ai.giskard.service.GeneralSettingsService;
56
import ai.giskard.service.ml.MLWorkerWSService;
67
import lombok.RequiredArgsConstructor;
78
import org.springframework.context.event.EventListener;
@@ -33,7 +34,8 @@ public void handleWorkerStatusChangeEvent(UpdateWorkerStatusEvent event) {
3334

3435

3536
public void sendCurrentStatus() {
36-
MLWorkerID workerID = MLWorkerID.EXTERNAL;
37+
MLWorkerID workerID = GeneralSettingsService.IS_RUNNING_IN_DEMO_HF_SPACES ?
38+
MLWorkerID.INTERNAL : MLWorkerID.EXTERNAL;
3739
Map<String, Boolean> data = new HashMap<>();
3840
data.put("connected", mlWorkerWSService.isWorkerConnected(workerID));
3941
sendData(data);

0 commit comments

Comments
 (0)