97 lines
2.5 KiB
Python
97 lines
2.5 KiB
Python
"""
|
||
计算s_r,即当轮剩余可用时间总和
|
||
|
||
输出:
|
||
json格式
|
||
{
|
||
"GID": [s_0, s_1, ..., s_n]
|
||
}
|
||
"""
|
||
import json
|
||
|
||
from island.match import Match
|
||
from island.matches import Matches
|
||
|
||
|
||
class Sr:
    """Compute S_r, the total remaining available time per round, for each game.

    Survivor data is loaded from a JSON file keyed by game name and then by
    round number as a string (``survivals[name][str(r)]``); only the length of
    each per-round entry (the number of survivors) is used.
    """

    # Time budget credited to every survivor in a round (the 1440 in the S_r
    # formula; presumably minutes in a day -- TODO confirm).
    TIME_PER_SURVIVOR = 1440

    def __init__(self, survivals_path: str = 'outputs/survivals_new.json',
                 datasets=('wos-data-2022-pd',)):
        """Load survivor data and the seasons to process.

        :param survivals_path: JSON file mapping game name -> round number
            (as a string) -> survivors of that round.
        :param datasets: Dataset identifiers; each one is wrapped in ``Matches``.
        """
        # Not read anywhere in this class; kept so existing callers still find it.
        self.details = {}
        with open(survivals_path, 'r', encoding='utf-8') as f:
            self.survivals = json.load(f)
        self.seasons = [Matches(d) for d in datasets]
        # Filled in by calc(): game name -> list of S_r values (rounds 1..n).
        self.results = {}

    def get_s_r(self, m: 'Match', r: int):
        """Return S_r, the total remaining time in round ``r`` of match ``m``.

        S_r = (number of survivors) * TIME_PER_SURVIVOR
              - 2 * (time ``tr`` of every completed ('done') action in round r)

        :param m: Match
        :param r: Round ID (calc_season iterates from 1)
        :returns: s_r
        :raises KeyError: if ``self.survivals`` has no entry for ``m.name`` or
            for round ``str(r)``.
        """
        actions = m.query('action', 'done').where(lambda x: x['rno'] == r).raw_data
        total_tr = len(self.survivals[m.name][str(r)]) * self.TIME_PER_SURVIVOR
        for act in actions:
            # Each completed action is charged twice, per the formula above
            # (presumably once per participant -- TODO confirm).
            total_tr -= act['tr'] * 2
        return total_tr

    def calc_season(self, season: 'Matches'):
        """Compute the S_r series for every match in ``season``.

        :param season: Matches
        :returns: dict mapping match name -> [S_1, ..., S_n], where n is the
            integer ``game_end_at`` from the match's ('game', 'created') record.
        """
        result = {}
        for m in season.data:
            game_end_at = int(m.query('game', 'created').first()['info']['game_end_at'])
            result[m.name] = [self.get_s_r(m, r) for r in range(1, game_end_at + 1)]
        return result

    def calc(self, output_path: str = 'outputs/S_R.json'):
        """Compute S_r for all seasons, write it as JSON and store it in ``self.results``.

        If two seasons contain the same match name, the later season wins
        (results are merged with ``dict.update``).

        :param output_path: Destination of the JSON dump.
        """
        result = {}
        for s in self.seasons:
            result.update(self.calc_season(s))
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(result, f)
        self.results = result

    def save_plot(self, name: str, out_dir: str = 'graph'):
        """Plot the S_r series of game ``name`` and save it to ``{out_dir}/s_r_{name}.pdf``.

        Uses ``self.results`` (filled by calc()); if ``name`` is not there, a
        message is printed and nothing is plotted.

        :param name: Game name, e.g. "G646".
        :param out_dir: Directory the PDF is written to.
        """
        if name not in self.results:
            print(f'{name} not found')
            return

        # Imported here: only plotting needs numpy/matplotlib.
        import numpy as np
        from matplotlib import pyplot as plt

        result = self.results[name]
        x = np.arange(1, len(result) + 1)  # round numbers start at 1

        fig = plt.figure(figsize=(6, 3))
        try:
            ax = fig.gca()
            ax.plot(x, result, 'o--', color='limegreen', linewidth=2, label=r"S_r")
            ax.tick_params(labelsize=14)
            ax.set_xlim(1, len(result))
            ax.set_xticks(x[::2])
            ax.set_xlabel("Rounds", size=22)
            ax.set_ylabel(r"$S_r$", family='sans-serif', size=22)
            ax.legend(numpoints=2, fontsize=14)
            fig.tight_layout()
            fig.savefig(f'{out_dir}/s_r_{name}.pdf')
        finally:
            # pyplot keeps every figure alive until it is closed explicitly;
            # without this, each call leaks a figure.
            plt.close(fig)
|
||
|
||
|
||
if __name__ == '__main__':
    # Compute S_r for every configured season, then render plots for a few games.
    sr = Sr()
    sr.calc()
    for game_id in ("G646", "G903", "G936", "G933"):
        sr.save_plot(game_id)
|