Newer
Older
BrainWave-Task-Visualizer / BrainWave_Pred.py
  1. import numpy as np
  2. from joblib import load
  3. import os
  4. # from __future__ import unicode_literals, print_function
  5. from socket import socket, AF_INET, SOCK_DGRAM
  6. from pythonosc import osc_message
  7. import time
  8.  
  9. HOST = ''
  10. PORT = 8001
  11. class EEGPredictor:
  12. def __init__(self, model_dir='trained_model'):
  13. self.model, self.scaler = self._load_model(model_dir)
  14. self.class_names = {0: 'neutral', 1: 'right', 2: 'left'}
  15. self.feature_names = [
  16. 'alpha', 'beta', 'theta', 'delta', 'gamma',
  17. 'alpha_beta_ratio', 'theta_beta_ratio', 'alpha_theta_ratio',
  18. 'alpha_rel', 'beta_rel', 'theta_rel', 'delta_rel', 'gamma_rel',
  19. 'alpha_log', 'beta_log', 'theta_log', 'delta_log', 'gamma_log'
  20. ]
  21.  
  22. def _load_model(self, model_dir):
  23. model_path = os.path.join(model_dir, 'rf_model.joblib')
  24. scaler_path = os.path.join(model_dir, 'scaler.joblib')
  25. return load(model_path), load(scaler_path)
  26.  
  27. def _create_features(self, basic_features):
  28. """基本的な脳波データから追加の特徴量を生成"""
  29. features = {}
  30. # 基本特徴量
  31. alpha, beta, theta, delta, gamma = basic_features
  32. features.update({
  33. 'alpha': alpha, 'beta': beta, 'theta': theta,
  34. 'delta': delta, 'gamma': gamma
  35. })
  36. # 比率の計算
  37. features['alpha_beta_ratio'] = alpha / beta if beta != 0 else 0
  38. features['theta_beta_ratio'] = theta / beta if beta != 0 else 0
  39. features['alpha_theta_ratio'] = alpha / theta if theta != 0 else 0
  40. # 相対パワーの計算
  41. total_power = sum(basic_features)
  42. if total_power != 0:
  43. features.update({
  44. 'alpha_rel': alpha / total_power,
  45. 'beta_rel': beta / total_power,
  46. 'theta_rel': theta / total_power,
  47. 'delta_rel': delta / total_power,
  48. 'gamma_rel': gamma / total_power
  49. })
  50. else:
  51. features.update({
  52. 'alpha_rel': 0, 'beta_rel': 0, 'theta_rel': 0,
  53. 'delta_rel': 0, 'gamma_rel': 0
  54. })
  55. # 対数変換
  56. features.update({
  57. 'alpha_log': np.log1p(alpha) if alpha > 0 else 0,
  58. 'beta_log': np.log1p(beta) if beta > 0 else 0,
  59. 'theta_log': np.log1p(theta) if theta > 0 else 0,
  60. 'delta_log': np.log1p(delta) if delta > 0 else 0,
  61. 'gamma_log': np.log1p(gamma) if gamma > 0 else 0
  62. })
  63. # 特徴量を正しい順序で並べ替え
  64. return [features[name] for name in self.feature_names]
  65.  
  66. def predict(self, eeg_data):
  67. """脳波データから予測を行う"""
  68. # 特徴量の生成
  69. features = self._create_features(eeg_data)
  70. features = np.array(features).reshape(1, -1)
  71. # スケーリングと予測
  72. scaled_data = self.scaler.transform(features)
  73. prediction = self.model.predict(scaled_data)[0]
  74. probabilities = self.model.predict_proba(scaled_data)[0]
  75. # 結果の整形
  76. result = {
  77. 'predicted_class': self.class_names[prediction],
  78. 'confidence': float(probabilities[prediction]),
  79. 'probabilities': {
  80. self.class_names[i]: float(prob)
  81. for i, prob in enumerate(probabilities)
  82. }
  83. }
  84. return result
  85. def process_eeg_data(raw_eeg_data):
  86. """
  87. 脳波データを処理して予測を行う
  88. Parameters:
  89. raw_eeg_data: [alpha, beta, theta, delta, gamma]の形式の脳波データ
  90. """
  91. predictor = EEGPredictor()
  92. result = predictor.predict(raw_eeg_data)
  93. print(f"Predicted class: {result['predicted_class']}")
  94. print(f"Confidence: {result['confidence']:.2f}")
  95. print("Class probabilities:")
  96. for class_name, prob in result['probabilities'].items():
  97. print(f" {class_name}: {prob:.2f}")
  98. return result
  99.  
  100. def Convert_BrainWave(data):
  101. msg = osc_message.OscMessage(data)
  102. #print(msg.params)
  103. types = msg.address
  104. arguments = []
  105. if types == "/Attention":
  106. arguments.append("Attention")
  107. arguments.append(float(msg.params[0]))
  108. elif types == "/Meditation":
  109. arguments.append("Meditation")
  110. arguments.append(float(msg.params[0]))
  111. elif types == "/BandPower":
  112. arguments.append("BandPower")
  113. arguments += list(map(float, msg.params[0].split(";")))
  114. #print(map(float, msg.params[0].split(";")))
  115. return arguments
  116. s = socket(AF_INET, SOCK_DGRAM)
  117. s.bind((HOST, PORT))
  118. while True:
  119. # 実際の脳波データをここに入れます
  120. print("受信待ち")
  121. data, address = s.recvfrom(1024)
  122. received_data = Convert_BrainWave(data=data)
  123. received_data = received_data[1:]
  124. print(received_data)
  125. result = process_eeg_data(received_data)
  126. time.sleep(0.5)