You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

203 lines
6.3 KiB

5 months ago
import os
import sys
sys.path.append(os.getcwd())
import time
import logging
import random
import re
import torch
import numpy as np
import pandas as pd
import datetime
from dateutil.relativedelta import relativedelta
def create_logger(log_path):
    """
    Return this module's logger, emitting records to a log file and the console.

    The function is safe to call repeatedly: a file handler for ``log_path``
    and a console handler are attached only when the logger does not already
    have one, so records are not duplicated and no extra file descriptors are
    opened on later calls.

    Args:
        log_path: Path of the log file to write to.

    Returns:
        logging.Logger: The logger named after this module, set to INFO level.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # FileHandler stores its target as an absolute path in ``baseFilename``,
    # so compare against the absolute form of the requested path.
    target_path = os.path.abspath(log_path)
    has_file_handler = any(
        isinstance(handler, logging.FileHandler)
        and handler.baseFilename == target_path
        for handler in logger.handlers
    )
    if not has_file_handler:
        # Handler that writes records to the log file.
        file_handler = logging.FileHandler(filename=log_path)
        file_handler.setFormatter(formatter)
        file_handler.setLevel(logging.INFO)
        logger.addHandler(file_handler)

    # FileHandler subclasses StreamHandler, so exclude it when looking for an
    # already-attached console handler.
    has_console_handler = any(
        isinstance(handler, logging.StreamHandler)
        and not isinstance(handler, logging.FileHandler)
        for handler in logger.handlers
    )
    if not has_console_handler:
        # Handler that echoes records to the console. The logger itself is at
        # INFO, so DEBUG records are still filtered out before reaching it.
        console = logging.StreamHandler()
        console.setLevel(logging.DEBUG)
        console.setFormatter(formatter)
        logger.addHandler(console)
    return logger
def get_file_name(fname):
    """
    Return the base name of ``fname`` up to (not including) its first dot.

    Directory components are dropped first, so ``/a/b/model.tar.gz`` yields
    ``model``.
    """
    base_name = os.path.basename(fname)
    stem, _, _ = base_name.partition(".")
    return stem
def get_file_size(fname):
    """
    Return the size of the file at ``fname`` in MB, rounded to two decimals.
    """
    bytes_per_mb = float(1024 * 1024)
    return round(os.path.getsize(fname) / bytes_per_mb, 2)
def set_seed(seed):
    """
    Seed the Python, NumPy and PyTorch random number generators.

    When CUDA is available, every CUDA device is seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def reduce_mem(df):
    """
    Shrink a DataFrame's memory footprint by downcasting numeric columns.

    Each int/float column is cast, in place, to the narrowest dtype whose
    range strictly contains the column's min and max. Columns whose min or
    max is null are left untouched. Float columns may be cast down to
    float16/float32, which keeps the range but not necessarily the precision.
    A summary line with the new size, the reduction and the elapsed time is
    printed.

    Returns:
        The same DataFrame object, after modification.
    """
    started = time.time()
    numerics = ['int16', 'int32', 'int64', 'float16', 'float32', 'float64']
    mem_before = df.memory_usage().sum() / 1024**2

    for name in df.columns:
        dtype = df[name].dtypes
        if dtype not in numerics:
            continue
        lowest = df[name].min()
        highest = df[name].max()
        if pd.isnull(lowest) or pd.isnull(highest):
            continue

        if str(dtype).startswith('int'):
            # Try integer widths from narrowest to widest; leave the column
            # as it is when none of them strictly contains the value range.
            for candidate in (np.int8, np.int16, np.int32, np.int64):
                limits = np.iinfo(candidate)
                if lowest > limits.min and highest < limits.max:
                    df[name] = df[name].astype(candidate)
                    break
        else:
            # Floats fall back to float64 when neither narrower type fits.
            target = np.float64
            for candidate in (np.float16, np.float32):
                limits = np.finfo(candidate)
                if lowest > limits.min and highest < limits.max:
                    target = candidate
                    break
            df[name] = df[name].astype(target)

    mem_after = df.memory_usage().sum() / 1024**2
    elapsed_minutes = (time.time() - started) / 60
    reduction_pct = 100 * (mem_before - mem_after) / mem_before
    print('-- Mem. usage decreased to {:5.2f} Mb ({:.1f}% reduction),time spend:{:2.2f} min'.format(
        mem_after, reduction_pct, elapsed_minutes))
    return df
def is_number(s):
    """
    Return True when ``s`` can be interpreted as a number.

    A value counts as a number when ``float(s)`` accepts it, or when it is a
    single Unicode character with a numeric value (for example a CJK numeral
    or a vulgar fraction). Values of unsupported types such as ``None`` or a
    list yield False rather than raising.
    """
    import unicodedata

    try:
        float(s)
        return True
    except OverflowError:
        # An int too large for a float is still a number.
        return True
    except (TypeError, ValueError):
        # TypeError: unsupported type (None, list, ...); ValueError: bad text.
        pass

    try:
        unicodedata.numeric(s)
        return True
    except (TypeError, ValueError):
        pass
    return False
def merge_dict(dict1, dict2):
    """
    Return a new dict with the entries of ``dict1`` and ``dict2`` combined.

    On duplicate keys the value from ``dict2`` wins. Neither input is modified.
    """
    merged = dict(dict1)
    merged.update(dict2)
    return merged
def random_dict_order(dict_data):
    """
    Return a new dict holding the same items as ``dict_data`` in random order.

    The input dict is not modified. The order is drawn with ``random.sample``,
    so it follows the global ``random`` seed.
    """
    # random.sample needs a sequence; a dict keys view is rejected on
    # Python 3.11+, so materialize the keys first.
    key_list = random.sample(list(dict_data), len(dict_data))
    return {key: dict_data[key] for key in key_list}
def get_before_date(n):
    """
    Return the date ``n`` days before now, formatted as ``YYYY-MM-DD``.
    """
    target_day = datetime.datetime.now() - datetime.timedelta(days=n)
    return target_day.strftime('%Y-%m-%d')
def get_user(myobject_path):
    """
    Return the ids of recently active users from a tab-separated activity file.

    The look-back window is measured from the newest ``created_at`` value in
    the file and is widened to account for lower activity around the winter
    and summer breaks:

    * newest record in month 5, 6, 11 or 12 -> previous 3 months
    * newest record in month 1 or 7         -> previous 4 months
    * newest record in month 2, 3, 4, 8, 9 or 10 -> previous 5 months

    Pass the path of myshixun.csv to get active users of practical training,
    or the path of mysubjecy.csv to get active users of practice courses.

    Args:
        myobject_path: Path of a UTF-8, tab-separated file with at least the
            columns ``created_at`` and ``user_id``.

    Returns:
        Array of unique ``user_id`` values inside the window. It is empty
        when the file has no usable ``created_at`` value.
    """
    activate = pd.read_csv(myobject_path, sep='\t', encoding='utf-8')
    activate["created_at"] = pd.to_datetime(activate["created_at"])

    # Series.max() skips missing timestamps; the builtin max() does not,
    # because every comparison against NaT is False. Compute it once.
    latest = activate["created_at"].max()
    if pd.isnull(latest):
        # No rows, or no parseable timestamps: nobody counts as active.
        return activate["user_id"].iloc[:0].unique()

    if latest.month in (2, 3, 4, 8, 9, 10):
        months_back = 5
    elif latest.month in (1, 7):
        months_back = 4
    else:
        months_back = 3

    cutoff = latest - relativedelta(months=months_back)
    activities = activate[activate["created_at"] >= cutoff]
    return activities["user_id"].unique()
# Remove special characters, punctuation and marker tokens from text.
def extract_word(wordlist):
    """
    Clean ``wordlist`` and return the result as a one-element list.

    The argument is converted with ``str()`` and every occurrence of the
    following is removed: the literal ``[SEP]`` token, ASCII digits,
    whitespace, the listed ASCII punctuation, the listed CJK/full-width
    punctuation, and the em dash.

    Returns:
        list[str]: A list containing the single cleaned string.
    """
    # "[SEP]" must be escaped: unescaped it is a character class matching one
    # of S/E/P, which would make the punctuation alternative require a
    # three-character sequence and never strip punctuation on its own.
    pattern = (
        r"\[SEP\]"
        r"|[0-9\s\.\!\/_,$%^*(\"\']"
        r"|[——!,。?、:?;;《》“”~@#¥%……&*—]"
    )
    word = re.sub(pattern, '', str(wordlist))
    return [word]
# Tokens dropped by cut_stop().
# NOTE(review): the many empty-string entries look like symbols/emoji that
# were lost in an encoding or copy step -- confirm against the original data.
# The former `'''·'` entry opened an unterminated triple-quoted string; it is
# written here as the two entries '' and '·'.
stop_list = [
    '', '', '', '', '', '', '', ')', '-', '|', ' ', '', '', '', '', '˘', '', ':', '', '', '',
    '', ']', '[', '°', '', '', '☺️', '?',
    '·', '×', '>', "`・∧・´", 'з', '',
    '', '×', '', '', '', '', '·', "[SEP]",
]
def cut_stop(s):
    """
    Concatenate the items of ``s`` that are not listed in ``stop_list``.

    ``s`` is iterated item by item: a string is filtered per character, a
    list of strings per element.

    Returns:
        str: The kept items joined together without a separator.
    """
    return ''.join(piece for piece in s if piece not in stop_list)
def finalcut(word_list):
    """
    Clean ``word_list`` with extract_word() and then filter it with cut_stop().

    NOTE(review): extract_word() returns a one-element list, so cut_stop()
    compares the whole cleaned string against the stop list rather than its
    individual characters -- confirm this is the intended granularity.

    Returns:
        str: The cleaned text.
    """
    return cut_stop(extract_word(word_list))