Browse Source

Merge branch 'master' into dev-xjf

pull/50/head
xjf 4 weeks ago
parent
commit
198647e38f
  1. 18
      yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/ModelController.java
  2. 8
      yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/model/ModelInfo.java
  3. 2
      yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/model/Point.java
  4. 10
      yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/vo/ModelInitVO.java
  5. 40
      yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/param/AnnTestParam.java
  6. 46
      yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/param/AnnTrainParam.java
  7. 58
      yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/param/TrainParam.java
  8. 6
      yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/service/model/ModelService.java
  9. 107
      yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/service/model/impl/ModelServiceImpl.java

18
yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/ModelController.java

@ -2,6 +2,8 @@ package cn.iocoder.yudao.module.alert.controller.admin.model;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.module.alert.controller.admin.model.vo.*;
import cn.iocoder.yudao.module.alert.param.AnnTestParam;
import cn.iocoder.yudao.module.alert.param.AnnTrainParam;
import cn.iocoder.yudao.module.alert.param.ModelTestParam;
import cn.iocoder.yudao.module.alert.param.TrainParam;
import cn.iocoder.yudao.module.alert.service.model.ModelService;
@ -72,12 +74,28 @@ public class ModelController {
}
}
@PostMapping("/train/ann")
public CommonResult<String> getAnnTrainData(@RequestBody AnnTrainParam param) {
try {
String trainInfo = modelService.trainAnn(param);
return CommonResult.success(trainInfo);
} catch (Exception e) {
return CommonResult.error(INTERNAL_SERVER_ERROR.getCode(), e.getMessage());
}
}
@PostMapping("/test")
public CommonResult<ModelTestData> getTestData(@RequestBody ModelTestParam param) {
ModelTestData modelTestData = modelService.getModelTestData(param);
return CommonResult.success(modelTestData);
}
@PostMapping("/test/ann")
public CommonResult<ModelTestData> getAnnTestData(@RequestBody AnnTestParam param) {
ModelTestData modelTestData = modelService.getAnnModelTestData(param);
return CommonResult.success(modelTestData);
}
/**
* 创建草稿版本v-test
*/

8
yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/model/ModelInfo.java

@ -69,6 +69,12 @@ public class ModelInfo {
@JsonProperty("rate")
private String rate;
@JsonProperty("layer")
private String layer;
@JsonProperty("iter")
private Integer iter;
@JsonProperty("outPointInfo")
private List<Point> outPointInfo;
@ -102,7 +108,7 @@ public class ModelInfo {
private AlarmModelSet alarmModelSet;
@JsonProperty("para")
private TrainInfo para;
private Object para;
@JsonProperty("principal")
private Integer principal;

2
yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/model/Point.java

@ -57,4 +57,6 @@ public class Point {
private BigDecimal C95;
@JsonProperty("C99")
private BigDecimal C99;
private boolean type;
}

10
yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/vo/ModelInitVO.java

@ -57,4 +57,14 @@ public class ModelInitVO {
private Integer sampling;
private String rate;
/**
* ANN 模型层级配置
*/
private String layer;
/**
* ANN 迭代次数
*/
private Integer iter;
}

40
yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/param/AnnTestParam.java

@ -0,0 +1,40 @@
package cn.iocoder.yudao.module.alert.param;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
* ANN 测试入参
*/
@Data
@AllArgsConstructor
@NoArgsConstructor
public class AnnTestParam {
/**
* 时间区间格式与原 PCA 测试保持一致
*/
private String time;
/**
* 点位集合
*/
private String points;
/**
* 采样间隔下游需转毫秒
*/
private Integer interval;
/**
* 模型内容字符串
*/
private String model;
/**
* 算法类型建议传 ANN
*/
private String type;
}

46
yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/param/AnnTrainParam.java

@ -0,0 +1,46 @@
package cn.iocoder.yudao.module.alert.param;
import com.fasterxml.jackson.annotation.JsonAlias;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
* ANN 训练入参
*/
@Data
@AllArgsConstructor
@NoArgsConstructor
public class AnnTrainParam {
@JsonAlias("Train_Data")
@JsonProperty("Train_Data")
private TrainParam.TrainData trainData;
/**
* 算法类型建议传 ANN
*/
private String type;
/**
* ANN 迭代次数
*/
private String iter;
/**
* ANN 隐层结构使用-分隔
*/
private String hide;
/**
* 条件保留原拼写以兼容旧入参
*/
private String conditon;
/**
* 正确拼写的条件字段兼容新入参
*/
@JsonAlias("condition")
private String condition;
}

58
yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/param/TrainParam.java

@ -16,11 +16,36 @@ public class TrainParam {
@JsonProperty("Train_Data")
private TrainData trainData;
/**
* 算法类型PCA / ANN
*/
private String type;
/**
* 条件保留原拼写以兼容旧入参
*/
private String conditon;
/**
* 正确拼写的条件字段兼容新入参
*/
@JsonAlias("condition")
private String condition;
/**
* PCA 专用配置
*/
private PcaParam pca;
/**
* ANN 专用配置如仍走旧接口可使用
*/
private AnnParam ann;
// === 旧字段兼容 ===
@JsonAlias("Hyper_para")
@JsonProperty("Hyper_para")
private HyperPara hyperPara;
private String type;
private String conditon;
@JsonAlias("smote_config")
@JsonProperty("smote_config")
@ -30,6 +55,10 @@ public class TrainParam {
private String targetPoint;
private String iter;
private String hide;
@Data
public static class Smote {
@ -57,4 +86,29 @@ public class TrainParam {
private String uplow;
private Integer interval;
}
@Data
@AllArgsConstructor
@NoArgsConstructor
public static class PcaParam {
@JsonAlias("Hyper_para")
@JsonProperty("Hyper_para")
private HyperPara hyperPara;
@JsonAlias("smote_config")
@JsonProperty("smote_config")
private List<Object> smoteConfig;
private Boolean smote;
private String targetPoint;
}
@Data
@AllArgsConstructor
@NoArgsConstructor
public static class AnnParam {
private String iter;
private String hide;
}
}

6
yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/service/model/ModelService.java

@ -2,7 +2,9 @@ package cn.iocoder.yudao.module.alert.service.model;
import cn.iocoder.yudao.module.alert.controller.admin.model.vo.*;
import cn.iocoder.yudao.module.alert.param.AnnTestParam;
import cn.iocoder.yudao.module.alert.param.ModelTestParam;
import cn.iocoder.yudao.module.alert.param.AnnTrainParam;
import cn.iocoder.yudao.module.alert.param.TrainParam;
import java.util.List;
@ -55,8 +57,12 @@ public interface ModelService {
TrainInfo trainModel(TrainParam param);
String trainAnn(AnnTrainParam param);
ModelTestData getModelTestData(ModelTestParam param);
ModelTestData getAnnModelTestData(AnnTestParam param);
/**
* 模型下装校验训练/评估后将模型状态置为已下装并落库版本
*

107
yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/service/model/impl/ModelServiceImpl.java

@ -20,6 +20,8 @@ import cn.iocoder.yudao.module.alert.dao.mapper.AssessReportCfgMapper;
import cn.iocoder.yudao.module.alert.dao.service.ModelCfgService;
import cn.iocoder.yudao.module.alert.dao.service.ModelVersionService;
import cn.iocoder.yudao.module.alert.dao.service.SystemCfgService;
import cn.iocoder.yudao.module.alert.param.AnnTestParam;
import cn.iocoder.yudao.module.alert.param.AnnTrainParam;
import cn.iocoder.yudao.module.alert.param.ModelTestParam;
import cn.iocoder.yudao.module.alert.param.TrainParam;
import cn.iocoder.yudao.module.alert.service.model.ModelService;
@ -124,11 +126,12 @@ public class ModelServiceImpl implements ModelService {
@Override
public Integer createModel(ModelInitVO model) {
try {
Algorithm algorithm = model.getAlgorithm();
ModelInfo info = new ModelInfo();
BeanUtil.copyProperties(model, info);
ModelCfg modelCfg = ModelCfg.builder()
.systemId(model.getSystemId())
.algorithmId(model.getAlgorithm().code)
.algorithmId(algorithm.code)
.modelName(model.getModelName())
.createTime(new Date())
.trash(ModelTrash.NORMAL.code)
@ -141,13 +144,25 @@ public class ModelServiceImpl implements ModelService {
info.setId(String.valueOf(modelId));
info.setSystemId(model.getSystemId());
info.setUnit(model.getUnit());
info.setAlgorithm(model.getAlgorithm());
info.setAlgorithm(algorithm);
info.setFounder(modelCfg.getCreator());
info.setCreateTime(modelCfg.getCreateTime());
info.setAlarmModelSet(ModelInfo.AlarmModelSet.defaultInit());
info.setName(model.getModelName());
info.setDescription(model.getDescription());
if (Algorithm.PCA.equals(algorithm)) {
info.setRate(model.getRate());
info.setLayer(null);
info.setIter(null);
} else if (Algorithm.ANN.equals(algorithm)) {
info.setLayer(model.getLayer());
info.setRate(null);
info.setIter(model.getIter());
} else {
info.setRate(model.getRate());
info.setLayer(model.getLayer());
info.setIter(model.getIter());
}
info.setTrainTime(new ArrayList<>());
modelCfg.setModelInfo(JsonUtils.toJsonString(info));
modelCfgService.updateById(modelCfg);
@ -161,6 +176,19 @@ public class ModelServiceImpl implements ModelService {
@Override
public Boolean updateModelInfo(ModelInfoVO modelInfo) {
ModelCfg existCfg = modelCfgService.getById(Integer.parseInt(modelInfo.getId()));
if (existCfg != null) {
Algorithm algorithm = Algorithm.of(existCfg.getAlgorithmId());
modelInfo.setAlgorithm(algorithm);
if (Algorithm.PCA.equals(algorithm)) {
modelInfo.setLayer(null);
modelInfo.setIter(null);
} else if (Algorithm.ANN.equals(algorithm)) {
modelInfo.setRate(null);
} else {
// 其他算法默认保留传入字段
}
}
modelInfo.setModifier(SecurityFrameworkUtils.getLoginUserNickname());
modelInfo.setModifiedTime(new Date());
ModelCfg modelCfg = ModelCfg.builder()
@ -196,7 +224,29 @@ public class ModelServiceImpl implements ModelService {
@Override
public TrainInfo trainModel(TrainParam param) {
String trainBody = HttpUtils.post(algorithmHost + "/api/test/ClearTrain", null, JsonUtils.toJsonString(param));
if ("ANN".equalsIgnoreCase(param.getType())) {
throw new RuntimeException("请使用 /alert/model/train/ann 接口进行 ANN 训练");
}
return trainPca(param);
}
private TrainInfo trainPca(TrainParam param) {
Map<String, Object> payload = new HashMap<>();
payload.put("Train_Data", param.getTrainData());
payload.put("type", param.getType());
payload.put("condition", param.getCondition());
payload.put("conditon", param.getConditon());
TrainParam.PcaParam pca = param.getPca();
TrainParam.HyperPara hyperPara = pca != null && pca.getHyperPara() != null ? pca.getHyperPara() : param.getHyperPara();
if (hyperPara != null) {
payload.put("Hyper_para", hyperPara);
}
payload.put("smote_config", pca != null ? pca.getSmoteConfig() : param.getSmoteConfig());
payload.put("smote", pca != null && pca.getSmote() != null ? pca.getSmote() : param.getSmote());
payload.put("targetPoint", pca != null ? pca.getTargetPoint() : param.getTargetPoint());
String trainBody = HttpUtils.post(algorithmHost + "/api/test/ClearTrain", null, JsonUtils.toJsonString(payload));
if (trainBody.contains("error_msg")) {
throw new RuntimeException("模型训练异常:" + JsonUtils.parseObject(trainBody,
new TypeReference<Map<String, String>>() {
@ -206,6 +256,42 @@ public class ModelServiceImpl implements ModelService {
});
}
@Override
public String trainAnn(AnnTrainParam param) {
Map<String, Object> payload = new HashMap<>();
TrainParam.TrainData trainData = param.getTrainData();
String time = trainData != null ? trainData.getTime() : null;
payload.put("time", time == null ? null : time.replaceAll(";+$", ""));
payload.put("point", trainData != null ? trainData.getPoints() : null);
Integer interval = trainData != null ? trainData.getInterval() : null;
payload.put("interval", interval == null ? null : interval * 1000);
String iter = param.getIter();
String hide = param.getHide();
payload.put("iter", iter);
payload.put("layer", StringUtils.hasText(hide) ? hide.split("-") : new String[]{});
payload.put("type", param.getType());
payload.put("dead", trainData != null ? trainData.getDead() : null);
payload.put("limit", trainData != null ? trainData.getLimit() : null);
payload.put("uplow", trainData != null ? trainData.getUplow() : null);
payload.put("condition", resolveCondition(param));
String trainBody = HttpUtils.post(algorithmHost + "/api/test/ANN_Train", null, JsonUtils.toJsonString(payload));
if (trainBody.contains("error_msg")) {
throw new RuntimeException("模型训练异常:" + JsonUtils.parseObject(trainBody,
new TypeReference<Map<String, String>>() {
}).get("error_msg"));
}
return trainBody;
}
private String resolveCondition(AnnTrainParam param) {
if (StringUtils.hasText(param.getCondition())) {
return param.getCondition();
}
return param.getConditon();
}
@Override
public ModelTestData getModelTestData(ModelTestParam param) {
@ -214,6 +300,21 @@ public class ModelServiceImpl implements ModelService {
return JsonUtils.parseObject(result, ModelTestData.class);
}
@Override
public ModelTestData getAnnModelTestData(AnnTestParam param) {
Map<String, Object> payload = new HashMap<>();
payload.put("point", param.getPoints());
payload.put("model", param.getModel());
payload.put("time", param.getTime());
Integer interval = param.getInterval();
payload.put("interval", interval == null ? null : interval);
payload.put("type", param.getType());
String result = HttpUtils.post(algorithmHost + "/api/test/ANN_Test", null, JsonUtils.toJsonString(payload))
.replace("NaN", "-1").replace("Infinity", "1");
return JsonUtils.parseObject(result, ModelTestData.class);
}
@Override
@Transactional(rollbackFor = Exception.class)
public ModelInfoVO bottomModel(Integer id, Long reportId) {

Loading…
Cancel
Save