Browse Source

feat(alert): 新增 ANN 模型训练与测试功能

- 添加 ANN 模型训练接口 /alert/model/train/ann
- 添加 ANN 模型测试接口 /alert/model/test/ann
- 扩展 ModelInfo 和 ModelInitVO 支持 ANN 层级和迭代参数
- 更新模型服务接口和实现以支持 ANN 特定逻辑
- 重构训练参数结构以区分 PCA 与 ANN 配置
- 修复模型创建时算法相关字段设置问题
- 增强模型更新时对不同算法参数的处理逻辑
- 优化点数据结构增加 type 字段支持
pull/49/head
chenjiale 4 weeks ago
parent
commit
64a7580169
  1. 18
      yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/ModelController.java
  2. 8
      yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/model/ModelInfo.java
  3. 2
      yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/model/Point.java
  4. 10
      yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/vo/ModelInitVO.java
  5. 58
      yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/param/TrainParam.java
  6. 6
      yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/service/model/ModelService.java
  7. 109
      yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/service/model/impl/ModelServiceImpl.java

18
yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/ModelController.java

@ -2,6 +2,8 @@ package cn.iocoder.yudao.module.alert.controller.admin.model;
import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.module.alert.controller.admin.model.vo.*; import cn.iocoder.yudao.module.alert.controller.admin.model.vo.*;
import cn.iocoder.yudao.module.alert.param.AnnTestParam;
import cn.iocoder.yudao.module.alert.param.AnnTrainParam;
import cn.iocoder.yudao.module.alert.param.ModelTestParam; import cn.iocoder.yudao.module.alert.param.ModelTestParam;
import cn.iocoder.yudao.module.alert.param.TrainParam; import cn.iocoder.yudao.module.alert.param.TrainParam;
import cn.iocoder.yudao.module.alert.service.model.ModelService; import cn.iocoder.yudao.module.alert.service.model.ModelService;
@ -72,12 +74,28 @@ public class ModelController {
} }
} }
@PostMapping("/train/ann")
public CommonResult<String> getAnnTrainData(@RequestBody AnnTrainParam param) {
try {
String trainInfo = modelService.trainAnn(param);
return CommonResult.success(trainInfo);
} catch (Exception e) {
return CommonResult.error(INTERNAL_SERVER_ERROR.getCode(), e.getMessage());
}
}
@PostMapping("/test") @PostMapping("/test")
public CommonResult<ModelTestData> getTestData(@RequestBody ModelTestParam param) { public CommonResult<ModelTestData> getTestData(@RequestBody ModelTestParam param) {
ModelTestData modelTestData = modelService.getModelTestData(param); ModelTestData modelTestData = modelService.getModelTestData(param);
return CommonResult.success(modelTestData); return CommonResult.success(modelTestData);
} }
@PostMapping("/test/ann")
public CommonResult<ModelTestData> getAnnTestData(@RequestBody AnnTestParam param) {
ModelTestData modelTestData = modelService.getAnnModelTestData(param);
return CommonResult.success(modelTestData);
}
/** /**
* 创建草稿版本v-test * 创建草稿版本v-test
*/ */

8
yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/model/ModelInfo.java

@ -69,6 +69,12 @@ public class ModelInfo {
@JsonProperty("rate") @JsonProperty("rate")
private String rate; private String rate;
@JsonProperty("layer")
private String layer;
@JsonProperty("iter")
private Integer iter;
@JsonProperty("outPointInfo") @JsonProperty("outPointInfo")
private List<Point> outPointInfo; private List<Point> outPointInfo;
@ -102,7 +108,7 @@ public class ModelInfo {
private AlarmModelSet alarmModelSet; private AlarmModelSet alarmModelSet;
@JsonProperty("para") @JsonProperty("para")
private TrainInfo para; private Object para;
@JsonProperty("principal") @JsonProperty("principal")
private Integer principal; private Integer principal;

2
yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/model/Point.java

@ -57,4 +57,6 @@ public class Point {
private BigDecimal C95; private BigDecimal C95;
@JsonProperty("C99") @JsonProperty("C99")
private BigDecimal C99; private BigDecimal C99;
private boolean type;
} }

10
yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/vo/ModelInitVO.java

@ -57,4 +57,14 @@ public class ModelInitVO {
private Integer sampling; private Integer sampling;
private String rate; private String rate;
/**
* ANN 模型层级配置
*/
private String layer;
/**
* ANN 迭代次数
*/
private Integer iter;
} }

58
yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/param/TrainParam.java

@ -16,11 +16,36 @@ public class TrainParam {
@JsonProperty("Train_Data") @JsonProperty("Train_Data")
private TrainData trainData; private TrainData trainData;
/**
* 算法类型PCA / ANN
*/
private String type;
/**
* 条件保留原拼写以兼容旧入参
*/
private String conditon;
/**
* 正确拼写的条件字段兼容新入参
*/
@JsonAlias("condition")
private String condition;
/**
* PCA 专用配置
*/
private PcaParam pca;
/**
* ANN 专用配置如仍走旧接口可使用
*/
private AnnParam ann;
// === 旧字段兼容 ===
@JsonAlias("Hyper_para") @JsonAlias("Hyper_para")
@JsonProperty("Hyper_para") @JsonProperty("Hyper_para")
private HyperPara hyperPara; private HyperPara hyperPara;
private String type;
private String conditon;
@JsonAlias("smote_config") @JsonAlias("smote_config")
@JsonProperty("smote_config") @JsonProperty("smote_config")
@ -30,6 +55,10 @@ public class TrainParam {
private String targetPoint; private String targetPoint;
private String iter;
private String hide;
@Data @Data
public static class Smote { public static class Smote {
@ -57,4 +86,29 @@ public class TrainParam {
private String uplow; private String uplow;
private Integer interval; private Integer interval;
} }
@Data
@AllArgsConstructor
@NoArgsConstructor
public static class PcaParam {
@JsonAlias("Hyper_para")
@JsonProperty("Hyper_para")
private HyperPara hyperPara;
@JsonAlias("smote_config")
@JsonProperty("smote_config")
private List<Object> smoteConfig;
private Boolean smote;
private String targetPoint;
}
@Data
@AllArgsConstructor
@NoArgsConstructor
public static class AnnParam {
private String iter;
private String hide;
}
} }

6
yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/service/model/ModelService.java

@ -2,7 +2,9 @@ package cn.iocoder.yudao.module.alert.service.model;
import cn.iocoder.yudao.module.alert.controller.admin.model.vo.*; import cn.iocoder.yudao.module.alert.controller.admin.model.vo.*;
import cn.iocoder.yudao.module.alert.param.AnnTestParam;
import cn.iocoder.yudao.module.alert.param.ModelTestParam; import cn.iocoder.yudao.module.alert.param.ModelTestParam;
import cn.iocoder.yudao.module.alert.param.AnnTrainParam;
import cn.iocoder.yudao.module.alert.param.TrainParam; import cn.iocoder.yudao.module.alert.param.TrainParam;
import java.util.List; import java.util.List;
@ -55,8 +57,12 @@ public interface ModelService {
TrainInfo trainModel(TrainParam param); TrainInfo trainModel(TrainParam param);
String trainAnn(AnnTrainParam param);
ModelTestData getModelTestData(ModelTestParam param); ModelTestData getModelTestData(ModelTestParam param);
ModelTestData getAnnModelTestData(AnnTestParam param);
/** /**
* 模型下装校验训练/评估后将模型状态置为已下装并落库版本 * 模型下装校验训练/评估后将模型状态置为已下装并落库版本
* *

109
yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/service/model/impl/ModelServiceImpl.java

@ -20,6 +20,8 @@ import cn.iocoder.yudao.module.alert.dao.mapper.AssessReportCfgMapper;
import cn.iocoder.yudao.module.alert.dao.service.ModelCfgService; import cn.iocoder.yudao.module.alert.dao.service.ModelCfgService;
import cn.iocoder.yudao.module.alert.dao.service.ModelVersionService; import cn.iocoder.yudao.module.alert.dao.service.ModelVersionService;
import cn.iocoder.yudao.module.alert.dao.service.SystemCfgService; import cn.iocoder.yudao.module.alert.dao.service.SystemCfgService;
import cn.iocoder.yudao.module.alert.param.AnnTestParam;
import cn.iocoder.yudao.module.alert.param.AnnTrainParam;
import cn.iocoder.yudao.module.alert.param.ModelTestParam; import cn.iocoder.yudao.module.alert.param.ModelTestParam;
import cn.iocoder.yudao.module.alert.param.TrainParam; import cn.iocoder.yudao.module.alert.param.TrainParam;
import cn.iocoder.yudao.module.alert.service.model.ModelService; import cn.iocoder.yudao.module.alert.service.model.ModelService;
@ -124,11 +126,12 @@ public class ModelServiceImpl implements ModelService {
@Override @Override
public Integer createModel(ModelInitVO model) { public Integer createModel(ModelInitVO model) {
try { try {
Algorithm algorithm = model.getAlgorithm();
ModelInfo info = new ModelInfo(); ModelInfo info = new ModelInfo();
BeanUtil.copyProperties(model, info); BeanUtil.copyProperties(model, info);
ModelCfg modelCfg = ModelCfg.builder() ModelCfg modelCfg = ModelCfg.builder()
.systemId(model.getSystemId()) .systemId(model.getSystemId())
.algorithmId(model.getAlgorithm().code) .algorithmId(algorithm.code)
.modelName(model.getModelName()) .modelName(model.getModelName())
.createTime(new Date()) .createTime(new Date())
.trash(ModelTrash.NORMAL.code) .trash(ModelTrash.NORMAL.code)
@ -141,13 +144,25 @@ public class ModelServiceImpl implements ModelService {
info.setId(String.valueOf(modelId)); info.setId(String.valueOf(modelId));
info.setSystemId(model.getSystemId()); info.setSystemId(model.getSystemId());
info.setUnit(model.getUnit()); info.setUnit(model.getUnit());
info.setAlgorithm(model.getAlgorithm()); info.setAlgorithm(algorithm);
info.setFounder(modelCfg.getCreator()); info.setFounder(modelCfg.getCreator());
info.setCreateTime(modelCfg.getCreateTime()); info.setCreateTime(modelCfg.getCreateTime());
info.setAlarmModelSet(ModelInfo.AlarmModelSet.defaultInit()); info.setAlarmModelSet(ModelInfo.AlarmModelSet.defaultInit());
info.setName(model.getModelName()); info.setName(model.getModelName());
info.setDescription(model.getDescription()); info.setDescription(model.getDescription());
info.setRate(model.getRate()); if (Algorithm.PCA.equals(algorithm)) {
info.setRate(model.getRate());
info.setLayer(null);
info.setIter(null);
} else if (Algorithm.ANN.equals(algorithm)) {
info.setLayer(model.getLayer());
info.setRate(null);
info.setIter(model.getIter());
} else {
info.setRate(model.getRate());
info.setLayer(model.getLayer());
info.setIter(model.getIter());
}
info.setTrainTime(new ArrayList<>()); info.setTrainTime(new ArrayList<>());
modelCfg.setModelInfo(JsonUtils.toJsonString(info)); modelCfg.setModelInfo(JsonUtils.toJsonString(info));
modelCfgService.updateById(modelCfg); modelCfgService.updateById(modelCfg);
@ -161,6 +176,19 @@ public class ModelServiceImpl implements ModelService {
@Override @Override
public Boolean updateModelInfo(ModelInfoVO modelInfo) { public Boolean updateModelInfo(ModelInfoVO modelInfo) {
ModelCfg existCfg = modelCfgService.getById(Integer.parseInt(modelInfo.getId()));
if (existCfg != null) {
Algorithm algorithm = Algorithm.of(existCfg.getAlgorithmId());
modelInfo.setAlgorithm(algorithm);
if (Algorithm.PCA.equals(algorithm)) {
modelInfo.setLayer(null);
modelInfo.setIter(null);
} else if (Algorithm.ANN.equals(algorithm)) {
modelInfo.setRate(null);
} else {
// 其他算法默认保留传入字段
}
}
modelInfo.setModifier(SecurityFrameworkUtils.getLoginUserNickname()); modelInfo.setModifier(SecurityFrameworkUtils.getLoginUserNickname());
modelInfo.setModifiedTime(new Date()); modelInfo.setModifiedTime(new Date());
ModelCfg modelCfg = ModelCfg.builder() ModelCfg modelCfg = ModelCfg.builder()
@ -196,7 +224,29 @@ public class ModelServiceImpl implements ModelService {
@Override @Override
public TrainInfo trainModel(TrainParam param) { public TrainInfo trainModel(TrainParam param) {
String trainBody = HttpUtils.post(algorithmHost + "/api/test/ClearTrain", null, JsonUtils.toJsonString(param)); if ("ANN".equalsIgnoreCase(param.getType())) {
throw new RuntimeException("请使用 /alert/model/train/ann 接口进行 ANN 训练");
}
return trainPca(param);
}
private TrainInfo trainPca(TrainParam param) {
Map<String, Object> payload = new HashMap<>();
payload.put("Train_Data", param.getTrainData());
payload.put("type", param.getType());
payload.put("condition", param.getCondition());
payload.put("conditon", param.getConditon());
TrainParam.PcaParam pca = param.getPca();
TrainParam.HyperPara hyperPara = pca != null && pca.getHyperPara() != null ? pca.getHyperPara() : param.getHyperPara();
if (hyperPara != null) {
payload.put("Hyper_para", hyperPara);
}
payload.put("smote_config", pca != null ? pca.getSmoteConfig() : param.getSmoteConfig());
payload.put("smote", pca != null && pca.getSmote() != null ? pca.getSmote() : param.getSmote());
payload.put("targetPoint", pca != null ? pca.getTargetPoint() : param.getTargetPoint());
String trainBody = HttpUtils.post(algorithmHost + "/api/test/ClearTrain", null, JsonUtils.toJsonString(payload));
if (trainBody.contains("error_msg")) { if (trainBody.contains("error_msg")) {
throw new RuntimeException("模型训练异常:" + JsonUtils.parseObject(trainBody, throw new RuntimeException("模型训练异常:" + JsonUtils.parseObject(trainBody,
new TypeReference<Map<String, String>>() { new TypeReference<Map<String, String>>() {
@ -206,6 +256,42 @@ public class ModelServiceImpl implements ModelService {
}); });
} }
@Override
public String trainAnn(AnnTrainParam param) {
Map<String, Object> payload = new HashMap<>();
TrainParam.TrainData trainData = param.getTrainData();
String time = trainData != null ? trainData.getTime() : null;
payload.put("time", time == null ? null : time.replaceAll(";+$", ""));
payload.put("point", trainData != null ? trainData.getPoints() : null);
Integer interval = trainData != null ? trainData.getInterval() : null;
payload.put("interval", interval == null ? null : interval * 1000);
String iter = param.getIter();
String hide = param.getHide();
payload.put("iter", iter);
payload.put("layer", StringUtils.hasText(hide) ? hide.split("-") : new String[]{});
payload.put("type", param.getType());
payload.put("dead", trainData != null ? trainData.getDead() : null);
payload.put("limit", trainData != null ? trainData.getLimit() : null);
payload.put("uplow", trainData != null ? trainData.getUplow() : null);
payload.put("condition", resolveCondition(param));
String trainBody = HttpUtils.post(algorithmHost + "/api/test/ANN_Train", null, JsonUtils.toJsonString(payload));
if (trainBody.contains("error_msg")) {
throw new RuntimeException("模型训练异常:" + JsonUtils.parseObject(trainBody,
new TypeReference<Map<String, String>>() {
}).get("error_msg"));
}
return trainBody;
}
private String resolveCondition(AnnTrainParam param) {
if (StringUtils.hasText(param.getCondition())) {
return param.getCondition();
}
return param.getConditon();
}
@Override @Override
public ModelTestData getModelTestData(ModelTestParam param) { public ModelTestData getModelTestData(ModelTestParam param) {
@ -214,6 +300,21 @@ public class ModelServiceImpl implements ModelService {
return JsonUtils.parseObject(result, ModelTestData.class); return JsonUtils.parseObject(result, ModelTestData.class);
} }
@Override
public ModelTestData getAnnModelTestData(AnnTestParam param) {
Map<String, Object> payload = new HashMap<>();
payload.put("point", param.getPoints());
payload.put("model", param.getModel());
payload.put("time", param.getTime());
Integer interval = param.getInterval();
payload.put("interval", interval == null ? null : interval);
payload.put("type", param.getType());
String result = HttpUtils.post(algorithmHost + "/api/test/ANN_Test", null, JsonUtils.toJsonString(payload))
.replace("NaN", "-1").replace("Infinity", "1");
return JsonUtils.parseObject(result, ModelTestData.class);
}
@Override @Override
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public ModelInfoVO bottomModel(Integer id, Long reportId) { public ModelInfoVO bottomModel(Integer id, Long reportId) {

Loading…
Cancel
Save