|
|
|
@ -20,6 +20,8 @@ import cn.iocoder.yudao.module.alert.dao.mapper.AssessReportCfgMapper; |
|
|
|
import cn.iocoder.yudao.module.alert.dao.service.ModelCfgService; |
|
|
|
import cn.iocoder.yudao.module.alert.dao.service.ModelVersionService; |
|
|
|
import cn.iocoder.yudao.module.alert.dao.service.SystemCfgService; |
|
|
|
import cn.iocoder.yudao.module.alert.param.AnnTestParam; |
|
|
|
import cn.iocoder.yudao.module.alert.param.AnnTrainParam; |
|
|
|
import cn.iocoder.yudao.module.alert.param.ModelTestParam; |
|
|
|
import cn.iocoder.yudao.module.alert.param.TrainParam; |
|
|
|
import cn.iocoder.yudao.module.alert.service.model.ModelService; |
|
|
|
@ -124,11 +126,12 @@ public class ModelServiceImpl implements ModelService { |
|
|
|
@Override |
|
|
|
public Integer createModel(ModelInitVO model) { |
|
|
|
try { |
|
|
|
Algorithm algorithm = model.getAlgorithm(); |
|
|
|
ModelInfo info = new ModelInfo(); |
|
|
|
BeanUtil.copyProperties(model, info); |
|
|
|
ModelCfg modelCfg = ModelCfg.builder() |
|
|
|
.systemId(model.getSystemId()) |
|
|
|
.algorithmId(model.getAlgorithm().code) |
|
|
|
.algorithmId(algorithm.code) |
|
|
|
.modelName(model.getModelName()) |
|
|
|
.createTime(new Date()) |
|
|
|
.trash(ModelTrash.NORMAL.code) |
|
|
|
@ -141,13 +144,25 @@ public class ModelServiceImpl implements ModelService { |
|
|
|
info.setId(String.valueOf(modelId)); |
|
|
|
info.setSystemId(model.getSystemId()); |
|
|
|
info.setUnit(model.getUnit()); |
|
|
|
info.setAlgorithm(model.getAlgorithm()); |
|
|
|
info.setAlgorithm(algorithm); |
|
|
|
info.setFounder(modelCfg.getCreator()); |
|
|
|
info.setCreateTime(modelCfg.getCreateTime()); |
|
|
|
info.setAlarmModelSet(ModelInfo.AlarmModelSet.defaultInit()); |
|
|
|
info.setName(model.getModelName()); |
|
|
|
info.setDescription(model.getDescription()); |
|
|
|
if (Algorithm.PCA.equals(algorithm)) { |
|
|
|
info.setRate(model.getRate()); |
|
|
|
info.setLayer(null); |
|
|
|
info.setIter(null); |
|
|
|
} else if (Algorithm.ANN.equals(algorithm)) { |
|
|
|
info.setLayer(model.getLayer()); |
|
|
|
info.setRate(null); |
|
|
|
info.setIter(model.getIter()); |
|
|
|
} else { |
|
|
|
info.setRate(model.getRate()); |
|
|
|
info.setLayer(model.getLayer()); |
|
|
|
info.setIter(model.getIter()); |
|
|
|
} |
|
|
|
info.setTrainTime(new ArrayList<>()); |
|
|
|
modelCfg.setModelInfo(JsonUtils.toJsonString(info)); |
|
|
|
modelCfgService.updateById(modelCfg); |
|
|
|
@ -161,6 +176,19 @@ public class ModelServiceImpl implements ModelService { |
|
|
|
|
|
|
|
@Override |
|
|
|
public Boolean updateModelInfo(ModelInfoVO modelInfo) { |
|
|
|
ModelCfg existCfg = modelCfgService.getById(Integer.parseInt(modelInfo.getId())); |
|
|
|
if (existCfg != null) { |
|
|
|
Algorithm algorithm = Algorithm.of(existCfg.getAlgorithmId()); |
|
|
|
modelInfo.setAlgorithm(algorithm); |
|
|
|
if (Algorithm.PCA.equals(algorithm)) { |
|
|
|
modelInfo.setLayer(null); |
|
|
|
modelInfo.setIter(null); |
|
|
|
} else if (Algorithm.ANN.equals(algorithm)) { |
|
|
|
modelInfo.setRate(null); |
|
|
|
} else { |
|
|
|
// 其他算法默认保留传入字段
|
|
|
|
} |
|
|
|
} |
|
|
|
modelInfo.setModifier(SecurityFrameworkUtils.getLoginUserNickname()); |
|
|
|
modelInfo.setModifiedTime(new Date()); |
|
|
|
ModelCfg modelCfg = ModelCfg.builder() |
|
|
|
@ -196,7 +224,29 @@ public class ModelServiceImpl implements ModelService { |
|
|
|
|
|
|
|
@Override |
|
|
|
public TrainInfo trainModel(TrainParam param) { |
|
|
|
String trainBody = HttpUtils.post(algorithmHost + "/api/test/ClearTrain", null, JsonUtils.toJsonString(param)); |
|
|
|
if ("ANN".equalsIgnoreCase(param.getType())) { |
|
|
|
throw new RuntimeException("请使用 /alert/model/train/ann 接口进行 ANN 训练"); |
|
|
|
} |
|
|
|
return trainPca(param); |
|
|
|
} |
|
|
|
|
|
|
|
private TrainInfo trainPca(TrainParam param) { |
|
|
|
Map<String, Object> payload = new HashMap<>(); |
|
|
|
payload.put("Train_Data", param.getTrainData()); |
|
|
|
payload.put("type", param.getType()); |
|
|
|
payload.put("condition", param.getCondition()); |
|
|
|
payload.put("conditon", param.getConditon()); |
|
|
|
|
|
|
|
TrainParam.PcaParam pca = param.getPca(); |
|
|
|
TrainParam.HyperPara hyperPara = pca != null && pca.getHyperPara() != null ? pca.getHyperPara() : param.getHyperPara(); |
|
|
|
if (hyperPara != null) { |
|
|
|
payload.put("Hyper_para", hyperPara); |
|
|
|
} |
|
|
|
payload.put("smote_config", pca != null ? pca.getSmoteConfig() : param.getSmoteConfig()); |
|
|
|
payload.put("smote", pca != null && pca.getSmote() != null ? pca.getSmote() : param.getSmote()); |
|
|
|
payload.put("targetPoint", pca != null ? pca.getTargetPoint() : param.getTargetPoint()); |
|
|
|
|
|
|
|
String trainBody = HttpUtils.post(algorithmHost + "/api/test/ClearTrain", null, JsonUtils.toJsonString(payload)); |
|
|
|
if (trainBody.contains("error_msg")) { |
|
|
|
throw new RuntimeException("模型训练异常:" + JsonUtils.parseObject(trainBody, |
|
|
|
new TypeReference<Map<String, String>>() { |
|
|
|
@ -206,6 +256,42 @@ public class ModelServiceImpl implements ModelService { |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
@Override |
|
|
|
public String trainAnn(AnnTrainParam param) { |
|
|
|
Map<String, Object> payload = new HashMap<>(); |
|
|
|
TrainParam.TrainData trainData = param.getTrainData(); |
|
|
|
String time = trainData != null ? trainData.getTime() : null; |
|
|
|
payload.put("time", time == null ? null : time.replaceAll(";+$", "")); |
|
|
|
payload.put("point", trainData != null ? trainData.getPoints() : null); |
|
|
|
Integer interval = trainData != null ? trainData.getInterval() : null; |
|
|
|
payload.put("interval", interval == null ? null : interval * 1000); |
|
|
|
|
|
|
|
String iter = param.getIter(); |
|
|
|
String hide = param.getHide(); |
|
|
|
|
|
|
|
payload.put("iter", iter); |
|
|
|
payload.put("layer", StringUtils.hasText(hide) ? hide.split("-") : new String[]{}); |
|
|
|
payload.put("type", param.getType()); |
|
|
|
payload.put("dead", trainData != null ? trainData.getDead() : null); |
|
|
|
payload.put("limit", trainData != null ? trainData.getLimit() : null); |
|
|
|
payload.put("uplow", trainData != null ? trainData.getUplow() : null); |
|
|
|
payload.put("condition", resolveCondition(param)); |
|
|
|
String trainBody = HttpUtils.post(algorithmHost + "/api/test/ANN_Train", null, JsonUtils.toJsonString(payload)); |
|
|
|
if (trainBody.contains("error_msg")) { |
|
|
|
throw new RuntimeException("模型训练异常:" + JsonUtils.parseObject(trainBody, |
|
|
|
new TypeReference<Map<String, String>>() { |
|
|
|
}).get("error_msg")); |
|
|
|
} |
|
|
|
return trainBody; |
|
|
|
} |
|
|
|
|
|
|
|
private String resolveCondition(AnnTrainParam param) { |
|
|
|
if (StringUtils.hasText(param.getCondition())) { |
|
|
|
return param.getCondition(); |
|
|
|
} |
|
|
|
return param.getConditon(); |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
@Override |
|
|
|
public ModelTestData getModelTestData(ModelTestParam param) { |
|
|
|
@ -214,6 +300,21 @@ public class ModelServiceImpl implements ModelService { |
|
|
|
return JsonUtils.parseObject(result, ModelTestData.class); |
|
|
|
} |
|
|
|
|
|
|
|
@Override |
|
|
|
public ModelTestData getAnnModelTestData(AnnTestParam param) { |
|
|
|
Map<String, Object> payload = new HashMap<>(); |
|
|
|
payload.put("point", param.getPoints()); |
|
|
|
payload.put("model", param.getModel()); |
|
|
|
payload.put("time", param.getTime()); |
|
|
|
Integer interval = param.getInterval(); |
|
|
|
payload.put("interval", interval == null ? null : interval); |
|
|
|
payload.put("type", param.getType()); |
|
|
|
|
|
|
|
String result = HttpUtils.post(algorithmHost + "/api/test/ANN_Test", null, JsonUtils.toJsonString(payload)) |
|
|
|
.replace("NaN", "-1").replace("Infinity", "1"); |
|
|
|
return JsonUtils.parseObject(result, ModelTestData.class); |
|
|
|
} |
|
|
|
|
|
|
|
@Override |
|
|
|
@Transactional(rollbackFor = Exception.class) |
|
|
|
public ModelInfoVO bottomModel(Integer id, Long reportId) { |
|
|
|
|