swdata/theta_r.py
2023-03-15 21:16:16 +08:00

135 lines
3.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
计算theta_r
theta_r_i=所有邻居的剩余时间
输出:
1. json格式(详细)
{
"GID": {
"RID": {
"PID": theta_r_i,
}
}
}
2. CSV格式(按轮平均值, CLASSIC/SURVIVE分别对应一个文件)
"""
import csv
import json
from functools import reduce
from sre_constants import MAX_REPEAT
from tkinter.tix import MAX
import numpy as np
from island.match import Match
from island.matches import Matches
# Number of per-round slots in the season-level accumulators allocated by
# theta_r_i.calc_season; round r (1-based) is stored at index r - 1.
# NOTE(review): presumably the largest `game_end_at` of any match. A match
# with more rounds than this would index past the end of those arrays; confirm.
MAX_ROUND = 28 # game_end_at
class theta_r_i:
    """Compute theta_r for every surviving player of every round of every match.

    For a player, theta_r is ``TIME_BUDGET * (number of surviving neighbours)``
    minus the ``tr`` of the round's finished actions that involve those
    neighbours, clamped to ``[0, TIME_BUDGET]``.

    Inputs are read from ``outputs/survivals_new.json`` and
    ``outputs/neighborhood_new.json``; results are written under ``outputs/``.
    """

    # Per-neighbour time allowance for one round, and the upper clamp for theta_r.
    # NOTE(review): presumably the number of minutes in a day; confirm.
    TIME_BUDGET = 1440

    def __init__(self):
        """Load the cached survivor/neighbour tables and the seasons to process."""
        # {match name: {round number: {player id: theta_r}}}, filled by calc_season.
        self.details = {}
        # Indexed below as survivals[match name][str(round number)] -> player ids.
        with open('outputs/survivals_new.json', 'r') as f:
            self.survivals = json.load(f)
        # Indexed below as neighbors[match name][str(player id)] -> player ids.
        with open('outputs/neighborhood_new.json', 'r') as f:
            self.neighbors = json.load(f)
        self.seasons = [
            dict(season=Matches('wos-data-2022-pd', network_type='BA'), name='NEW_BA'),
            dict(season=Matches('wos-data-2022-pd', network_type='WS'), name='NEW_WS'),
        ]

    def getNeighborTR(self, m, r, p, s):
        """
        Return the time left over for a set of neighbours in one round.

        :param m: Match
        :param r: Round ID (compared against each action's ``rno``)
        :param p: PID (unused; kept for interface compatibility)
        :param s: Survivals list (neighborhood)
        :returns: ``TIME_BUDGET * len(s)`` minus the summed ``tr`` of the
                  round's finished actions touching ``s``; not clamped
        """
        members = set(s)  # built once: O(1) membership tests per action
        actions = m.query('action', 'done') \
            .where(lambda x: x['rno'] == r and (x['a'] in members or x['b'] in members)) \
            .raw_data
        # An action whose both ends are neighbours counts against both of them.
        spent = sum(
            item['tr'] * (2 if item['a'] in members and item['b'] in members else 1)
            for item in actions
        )
        return self.TIME_BUDGET * len(s) - spent

    def getNeighborhood(self, m, r, p):
        """
        Return the player's neighbours that are still alive in the given round.

        :param m: Match
        :param r: Round ID
        :param p: PID
        :returns: Survivals list (neighborhood), in the order of the round's
                  survivor list; empty when the player has no neighbour entry
        """
        match_neighbors = self.neighbors[m.name]
        key = str(p)
        if key not in match_neighbors:
            print("Alone(%d)!" % p)
            return []
        adjacent = set(match_neighbors[key])
        return [i for i in self.survivals[m.name][str(r)] if i in adjacent]

    def calcRoundData(self, m, r):
        """
        Compute theta_r for every survivor of one round of one match.

        :param m: Match
        :param r: zero-based round index; the round actually looked up is ``r + 1``
        :returns: ``(sigma, detail)`` where ``detail`` maps each surviving PID
                  to its theta_r and ``sigma`` is the sum of those values
        """
        rno = r + 1
        ans = {}
        sigma = 0.0
        for p in self.survivals[m.name][str(rno)]:
            neighborhood = self.getNeighborhood(m, rno, p)
            remaining = self.getNeighborTR(m, rno, p, neighborhood)
            # Clamp to [0, TIME_BUDGET]. The raw value grows with the number of
            # neighbours, so it can exceed the budget of a single player.
            e = max(0, min(self.TIME_BUDGET, remaining))
            ans[p] = e
            sigma += e
        return (sigma, ans)

    def calc_season(self, season, name):
        """
        Compute per-round theta_r averages for one season and write them out.

        Writes ``outputs/THETA_<match name>.csv`` for each match (one row of
        per-round averages) and ``outputs/THETA_<name>.csv`` for the season,
        and stores the per-player values of each match in ``self.details``.

        :param season: object whose ``data`` attribute iterates over matches
        :param name: label used in the season-level CSV file name
        :returns: array of ``MAX_ROUND`` season-wide per-round averages
        """
        avg = np.zeros(MAX_ROUND)
        cnt = np.zeros(MAX_ROUND)
        for m in season.data:
            d = {}
            cur_game_avg = []
            game_end_at = int(m.query('game', 'created').first()['info']['game_end_at'])
            for r in range(1, game_end_at + 1):
                sigma, detail = self.calcRoundData(m, r - 1)
                d[r] = detail
                avg[r - 1] += sigma
                cnt[r - 1] += len(detail)
                # A round without survivors averages to 0, the same convention
                # the season-level average uses, instead of dividing by zero.
                cur_game_avg.append(sigma / len(detail) if detail else 0.0)
            self.details[m.name] = d
            # newline='' lets the csv module control line endings itself.
            with open(f'outputs/THETA_{m.name}.csv', 'w', newline='') as f:
                csv.writer(f).writerow(cur_game_avg)
        print(cnt)
        # Rounds nobody survived keep an average of 0 instead of 0 / 0.
        cnt[cnt == 0] = 1
        avg /= cnt
        with open(f'outputs/THETA_{name}.csv', 'w', newline='') as f:
            csv.writer(f).writerow(avg)
        return avg

    def calc(self):
        """Process every configured season, then dump all per-player details as JSON."""
        for s in self.seasons:
            self.calc_season(s['season'], s['name'])
        with open('outputs/THETA_new_detail.json', 'w') as f:
            json.dump(self.details, f)
if __name__ == '__main__':
    # Script entry point: load the cached inputs, then compute and write
    # every theta_r output file.
    theta_r_i().calc()