Skip to content

Commit

Permalink
Merge pull request #81 from CSID-DGU/feature/#68/prediction
Browse files Browse the repository at this point in the history
[feat] : 딥러닝 가격 예측 조회 API 구현 (GET)
  • Loading branch information
bbbang105 authored Jun 18, 2024
2 parents e8f2eaf + 3e1dfa7 commit 138b486
Show file tree
Hide file tree
Showing 14 changed files with 325 additions and 6 deletions.
2 changes: 0 additions & 2 deletions backend/src/main/java/org/dgu/backend/BackendApplication.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.scheduling.annotation.EnableScheduling;

@SpringBootApplication
@EnableScheduling
public class BackendApplication {

public static void main(String[] args) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ public enum SuccessStatus implements BaseCode {
// Trading
SUCCESS_START_TRADING(HttpStatus.CREATED, "201", "자동매매 등록에 성공했습니다."),
SUCCESS_DELETE_TRADING(HttpStatus.OK, "200", "자동매매 삭제에 성공했습니다."),
SUCCESS_GET_TRADING_LOGS(HttpStatus.OK, "200", "자동매매 거래 로그 조회에 성공했습니다.");
SUCCESS_GET_TRADING_LOGS(HttpStatus.OK, "200", "자동매매 거래 로그 조회에 성공했습니다."),
// Prediction
SUCCESS_GET_PREDICTIONS(HttpStatus.OK, "200", "딥러닝 가격 예측 값 조회에 성공했습니다.");

private final HttpStatus httpStatus;
private final String code;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package org.dgu.backend.controller;

import lombok.RequiredArgsConstructor;
import org.dgu.backend.common.ApiResponse;
import org.dgu.backend.common.constant.SuccessStatus;
import org.dgu.backend.dto.PredictionDto;
import org.dgu.backend.service.PredictionDataScheduler;
import org.dgu.backend.service.PredictionService;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

import java.util.List;

@RestController
@RequestMapping("/api/v1/prediction")
@RequiredArgsConstructor
public class PredictionController {
private final PredictionService predictionService;
private final PredictionDataScheduler predictionDataScheduler;

// 딥러닝 가격 예측 값 조회 API
@GetMapping
public ResponseEntity<ApiResponse<List<PredictionDto.PredictionResponse>>> getPredictions() {

List<PredictionDto.PredictionResponse> predictionResponses = predictionService.getPredictions();
return ApiResponse.onSuccess(SuccessStatus.SUCCESS_GET_PREDICTIONS, predictionResponses);
}

// Train 수동 API
@GetMapping("/train")
public void startTrain() {

predictionDataScheduler.startTrain();
}

// 가격 예측 값 업데이트 수동 API
@GetMapping("/update")
public void getPrediction() {

predictionDataScheduler.getPrediction();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
@RequiredArgsConstructor
public class TradingController {
private final TradingService tradingService;
private final UpbitAutoTrader upbitAutoTrader;

// 자동매매 등록 API
@PostMapping
Expand All @@ -37,12 +38,18 @@ public ResponseEntity<ApiResponse<Object>> removeAutoTrading(
return ApiResponse.onSuccess(SuccessStatus.SUCCESS_DELETE_TRADING);
}

// 자동매매 수동 테스트 API
// 자동매매 거래 로그 조회 API
@GetMapping("/logs")
public ResponseEntity<ApiResponse<List<TradingDto.TradingLog>>> getUserTradingLogs(
@RequestHeader("Authorization") String authorizationHeader) {

List<TradingDto.TradingLog> tradingLogs = tradingService.getUserTradingLogs(authorizationHeader);
return ApiResponse.onSuccess(SuccessStatus.SUCCESS_GET_TRADING_LOGS, tradingLogs);
}

// 자동매매 수동 테스트 API
@GetMapping("/test")
public void startTrading() {
upbitAutoTrader.performAutoTrading();
}
}
33 changes: 33 additions & 0 deletions backend/src/main/java/org/dgu/backend/domain/Prediction.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package org.dgu.backend.domain;

import jakarta.persistence.*;
import lombok.*;
import org.dgu.backend.common.BaseEntity;

import java.time.Instant;
import java.time.LocalDateTime;
import java.time.ZoneOffset;

@Entity
@NoArgsConstructor(access = AccessLevel.PROTECTED)
@AllArgsConstructor(access = AccessLevel.PROTECTED)
@Builder
@Getter
@Table(name = "predictions")
public class Prediction extends BaseEntity {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
@Column(name = "predictions_id")
private Long id;

@Column(name = "date")
private LocalDateTime date;

@Column(name = "close")
private Long close;

public Prediction(String epochTime, Long close) {
this.date = LocalDateTime.ofInstant(Instant.ofEpochMilli(Long.parseLong(epochTime)), ZoneOffset.UTC);
this.close = (long) Math.round(close);
}
}
38 changes: 38 additions & 0 deletions backend/src/main/java/org/dgu/backend/dto/PredictionDto.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package org.dgu.backend.dto;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import org.dgu.backend.domain.Prediction;

import java.util.ArrayList;
import java.util.List;

public class PredictionDto {
@Builder
@Getter
@AllArgsConstructor
@JsonNaming(value = PropertyNamingStrategies.SnakeCaseStrategy.class)
@JsonInclude(JsonInclude.Include.NON_NULL)
public static class PredictionResponse {
@JsonProperty("date")
private String date;
@JsonProperty("close")
private Long close;

public static List<PredictionResponse> ofPredictions(List<Prediction> predictions) {
List<PredictionResponse> predictionResponses = new ArrayList<>();
for (Prediction prediction : predictions) {
predictionResponses.add(PredictionResponse.builder()
.date(String.valueOf(prediction.getDate()))
.close(prediction.getClose())
.build());
}
return predictionResponses;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import lombok.extern.slf4j.Slf4j;
import org.dgu.backend.common.ApiResponse;
import org.dgu.backend.common.code.BaseErrorCode;
import org.dgu.backend.domain.Market;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.MissingRequestHeaderException;
Expand Down Expand Up @@ -87,4 +86,10 @@ public ResponseEntity<ApiResponse<BaseErrorCode>> handleTradingException(Trading
TradingErrorResult errorResult = e.getTradingErrorResult();
return ApiResponse.onFailure(errorResult);
}
// Prediction
@ExceptionHandler(PredictionException.class)
public ResponseEntity<ApiResponse<BaseErrorCode>> handlePredictionException(PredictionException e) {
PredictionErrorResult errorResult = e.getPredictionErrorResult();
return ApiResponse.onFailure(errorResult);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package org.dgu.backend.exception;

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.dgu.backend.common.code.BaseErrorCode;
import org.dgu.backend.common.dto.ErrorReasonDto;
import org.springframework.http.HttpStatus;

@Getter
@RequiredArgsConstructor
public enum PredictionErrorResult implements BaseErrorCode {
FAIL_TO_TRAINING(HttpStatus.NOT_FOUND, "404", "딥러닝 트레이닝에 실패했습니다."),
FAIL_TO_PREDICTION(HttpStatus.NOT_FOUND, "404", "딥러닝 가격 예측 데이터 받아 오기에 실패했습니다."),
FAIL_TO_PARSE_RESPONSE(HttpStatus.NOT_FOUND, "404", "가격 예측 데이터 파싱에 실패했습니다");

private final HttpStatus httpStatus;
private final String code;
private final String message;

@Override
public ErrorReasonDto getReason() {
return ErrorReasonDto.builder()
.isSuccess(false)
.code(code)
.message(message)
.build();
}

@Override
public ErrorReasonDto getReasonHttpStatus() {
return ErrorReasonDto.builder()
.isSuccess(false)
.httpStatus(httpStatus)
.code(code)
.message(message)
.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package org.dgu.backend.exception;

import lombok.Getter;
import lombok.RequiredArgsConstructor;

@Getter
@RequiredArgsConstructor
public class PredictionException extends RuntimeException {
private final PredictionErrorResult predictionErrorResult;

@Override
public String getMessage() {
return predictionErrorResult.getMessage();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package org.dgu.backend.repository;

import org.dgu.backend.domain.Prediction;
import org.springframework.data.jpa.repository.JpaRepository;

public interface PredictionRepository extends JpaRepository<Prediction,Long> {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package org.dgu.backend.service;

import com.fasterxml.jackson.databind.ObjectMapper;
import jakarta.transaction.Transactional;
import lombok.RequiredArgsConstructor;
import org.dgu.backend.domain.Prediction;
import org.dgu.backend.dto.ChartDto;
import org.dgu.backend.dto.PredictionDto;
import org.dgu.backend.exception.PredictionErrorResult;
import org.dgu.backend.exception.PredictionException;
import org.dgu.backend.repository.PredictionRepository;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpMethod;
import org.springframework.http.ResponseEntity;
import org.springframework.scheduling.annotation.EnableScheduling;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

@Component
@RequiredArgsConstructor
@Transactional
@EnableScheduling
public class PredictionDataScheduler {
@Value("${ai.url.train}")
private String AI_URL_TRAIN;
@Value("${ai.url.predict}")
private String AI_URL_PREDICT;
private final ChartService chartService;
private final PredictionRepository predictionRepository;
private final RestTemplate restTemplate;
private final ObjectMapper objectMapper;

// Train 실행 메서드
//@Scheduled(cron = "0 0 0 * * *") // 매일 00:00에 실행
public void startTrain() {
List<ChartDto.OHLCVResponse> ohlcvResponses = chartService.getOHLCVCharts("비트코인", "days", null);
// Train 요청
ResponseEntity<String> trainResponseEntity = restTemplate.exchange(
AI_URL_TRAIN,
HttpMethod.POST,
new HttpEntity<>(ohlcvResponses),
String.class
);
String trainMessage = trainResponseEntity.getBody();
if (Objects.isNull(trainMessage)) {
throw new PredictionException(PredictionErrorResult.FAIL_TO_TRAINING);
}
System.out.println("Train Message: " + trainMessage);

}

// Prediction 값을 받아오는 메서드
//@Scheduled(cron = "0 10 0 * * *") // 매일 00:10에 실행
public void getPrediction() {
List<ChartDto.OHLCVResponse> ohlcvResponses = chartService.getOHLCVCharts("비트코인", "days", null);
// Prediction 요청
ResponseEntity<String> predictResponseEntity = restTemplate.exchange(
AI_URL_PREDICT,
HttpMethod.POST,
new HttpEntity<>(ohlcvResponses),
String.class
);
if (Objects.isNull(predictResponseEntity.getBody())) {
throw new PredictionException(PredictionErrorResult.FAIL_TO_PREDICTION);
}
String responseBody = predictResponseEntity.getBody();

// JSON 문자열을 PredictionDto 배열로 변환
PredictionDto.PredictionResponse[] predictions;
try {
predictions = objectMapper.readValue(responseBody, PredictionDto.PredictionResponse[].class);
} catch (IOException e) {
throw new PredictionException(PredictionErrorResult.FAIL_TO_PARSE_RESPONSE);
}

// 기존 값 제거
List<Prediction> existPredictions = predictionRepository.findAll();
if (!Objects.isNull(existPredictions)) {
predictionRepository.deleteAll(existPredictions);
predictionRepository.flush();
}
// 변환된 데이터를 엔티티로 저장
Arrays.stream(predictions)
.map(prediction -> new Prediction(prediction.getDate(), prediction.getClose()))
.forEach(predictionRepository::save);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package org.dgu.backend.service;

import org.dgu.backend.dto.PredictionDto;

import java.util.List;

public interface PredictionService {
List<PredictionDto.PredictionResponse> getPredictions();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package org.dgu.backend.service;

import jakarta.transaction.Transactional;
import lombok.RequiredArgsConstructor;
import org.dgu.backend.domain.Prediction;
import org.dgu.backend.dto.PredictionDto;
import org.dgu.backend.repository.PredictionRepository;
import org.springframework.stereotype.Service;

import java.util.List;

@Service
@Transactional
@RequiredArgsConstructor
public class PredictionServiceImpl implements PredictionService {
private final PredictionRepository predictionRepository;

// 딥러닝 가격 예측 값 반환 메서드
@Override
public List<PredictionDto.PredictionResponse> getPredictions() {
List<Prediction> predictions = predictionRepository.findAll();
return PredictionDto.PredictionResponse.ofPredictions(predictions);
}
}
Loading

0 comments on commit 138b486

Please sign in to comment.