From 39a9a9d04785618b9a6b5e8c95e46163ecb02ee9 Mon Sep 17 00:00:00 2001
From: "zongxi.li" <lizx@wealthgrow.cn>
Date: Thu, 3 Dec 2020 17:34:47 +0800
Subject: [PATCH] =?UTF-8?q?=E8=AF=84=E4=BB=B7=E4=BF=AE=E6=94=B9?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 app/api/engine.py                 |  12 +++
 app/config/config.yaml            |   1 +
 app/service/portfolio_diagnose.py | 146 +++++++++++++++++----------
 app/utils/fund_rank.py            | 157 ++++++++++++++++++------------
 4 files changed, 203 insertions(+), 113 deletions(-)

diff --git a/app/api/engine.py b/app/api/engine.py
index 0f70bf7..a79bc67 100644
--- a/app/api/engine.py
+++ b/app/api/engine.py
@@ -57,6 +57,18 @@ tamp_user_engine = create_engine(
     ),
     echo=True
 )
+
+tamp_fund_engine = create_engine(
+    'mysql+pymysql://{user}:{password}@{host}:{port}/{db}?charset={charset}'.format(
+        db=config[env]['MySQL']['tamp_fund_db'],
+        host=config[env]['MySQL']['host'],
+        port=config[env]['MySQL']['port'],
+        user=config[env]['MySQL']['user'],
+        password=config[env]['MySQL']['password'],
+        charset="utf8"
+    ),
+    echo=True
+)
 # tamp_product_session = scoped_session(sessionmaker(bind=tamp_product_engine))()
 # tamp_order_session = scoped_session(sessionmaker(bind=tamp_order_engine))()
 # tamp_user_session = scoped_session(sessionmaker(bind=tamp_user_engine))()
diff --git a/app/config/config.yaml b/app/config/config.yaml
index bddca31..07f84fc 100644
--- a/app/config/config.yaml
+++ b/app/config/config.yaml
@@ -43,6 +43,7 @@ prod:
     tamp_product_db: tamp_product
     tamp_order_db: tamp_order
     tamp_user_db: tamp_user
+    tamp_fund_db: tamp_fund
     host: tamper.mysql.polardb.rds.aliyuncs.com
     port: 3306
     user: tamp_admin
diff --git a/app/service/portfolio_diagnose.py b/app/service/portfolio_diagnose.py
index 818afbe..591a758 100644
--- a/app/service/portfolio_diagnose.py
+++ b/app/service/portfolio_diagnose.py
@@ -1,9 +1,15 @@
+# -*- coding: UTF-8 -*-
+# """
+# @author: Zongxi.Li
+# @file:portfolio_copy.py
+# @time:2020/12/03
+# """
 from app.utils.fund_rank import *
 from app.utils.risk_parity import *
 from app.pypfopt import risk_models
 from app.pypfopt import expected_returns
 from app.pypfopt import EfficientFrontier
-from app.api.engine import tamp_user_engine, tamp_product_engine, TAMP_SQL
+from app.api.engine import tamp_product_engine, tamp_fund_engine, TAMP_SQL
 
 
 def cal_correlation(prod):
@@ -15,10 +21,10 @@ def cal_correlation(prod):
     Returns:屏蔽基金与自身相关性的相关矩阵,因为基金与自身相关性为1,妨碍后续高相关性基金筛选的判断
 
     """
-    prod_return = prod.iloc[:, :].apply(lambda x: simple_return(x))
+    prod_return = prod.iloc[:, :].apply(lambda x: simple_return(x).astype(float))
     correlation = prod_return.corr()
     correlation = correlation.round(2)
-    return correlation.mask(np.eye(correlation.shape[0], dtype=np.bool))
+    return correlation.mask(np.eye(correlation.shape[0], dtype=np.bool_))
 
 
 def rename_col(df, fund_id):
@@ -66,7 +72,7 @@ def search_rank(fund_rank, fund, metric):
     return fund_rank[fund_rank['fund_id'] == fund][metric].values[0]
 
 
-def translate_single(content, evaluation):
+def translate_single(content, content_id, evaluation):
     '''
     content = [["优秀","良好","一般"],
            ["优秀","良好","合格","较差"],
@@ -74,7 +80,20 @@ def translate_single(content, evaluation):
            ["高","一般","较低"]]
     evaluation = [0,1,1,2]
     '''
-    return tuple([content[i][v] if type(v) == int else v for i, v in enumerate(evaluation)])
+    ret = []
+    for i, v in enumerate(evaluation):
+        if isinstance(v, str):
+            ret.append(v)
+            continue
+        elif content[content_id][i][v] in ["优秀", "良好", "高", "高于", "较好"]:
+            ret.append("""<span class="self_description_red">{}</span>""".format(content[content_id][i][v]))
+            continue
+        elif content_id == 4 and v == 0:
+            ret.append("""<span class="self_description_red">{}</span>""".format(content[content_id][i][v]))
+            continue
+        else:
+            ret.append("""<span class="self_description_green">{}</span>""".format(content[content_id][i][v]))
+    return tuple(ret)
 
 
 def choose_good_evaluation(evaluation):
@@ -127,10 +146,20 @@ def choose_bad_evaluation(evaluation):
 
 
 def get_fund_rank():
-    sql = "SELECT * FROM fund_rank"
-    df = pd.read_sql(sql, con)
-    # df = pd.read_csv('fund_rank.csv', encoding='gbk')
-    return df
+    with TAMP_SQL(tamp_fund_engine) as tamp_product:
+        tamp_product_session = tamp_product.session
+        sql = "SELECT * FROM fund_rank"
+
+        # df = pd.read_sql(sql, con)
+        # df = pd.read_csv('fund_rank.csv', encoding='gbk')
+        cur = tamp_product_session.execute(sql)
+        data = cur.fetchall()
+        df = pd.DataFrame(list(data), columns=['index', 'fund_id', 'range_return', 'annual_return', 'max_drawdown',
+                                               'sharp_ratio', 'volatility', 'sortino_ratio', 'downside_risk',
+                                               'substrategy', 'manager', 'annual_return_rank', 'downside_risk_rank',
+                                               'max_drawdown_rank', 'sharp_ratio_rank', 'z_score'])
+        df.drop('index', axis=1, inplace=True)
+        return df
 
 
 def get_index_daily(index_id):
@@ -142,13 +171,19 @@ def get_index_daily(index_id):
     Returns:与组合净值形式相同的表
 
     """
-    sql = "SELECT ts_code, trade_date, close FROM index_daily WHERE ts_code='{}'".format(index_id)
-    df = pd.read_sql(sql, con).dropna(how='any')
-    df.rename({'ts_code': 'fund_id', 'trade_date': 'end_date', 'close': 'adj_nav'}, axis=1, inplace=True)
-    df['end_date'] = pd.to_datetime(df['end_date'])
-    df.set_index('end_date', drop=True, inplace=True)
-    df.sort_index(inplace=True, ascending=True)
-    df = rename_col(df, index_id)
+    with TAMP_SQL(tamp_fund_engine) as tamp_product:
+        tamp_product_session = tamp_product.session
+        sql = "SELECT ts_code, trade_date, close FROM index_daily WHERE ts_code='{}'".format(index_id)
+        # df = pd.read_sql(sql, con).dropna(how='any')
+        cur = tamp_product_session.execute(sql)
+        data = cur.fetchall()
+
+        df = pd.DataFrame(list(data), columns=['ts_code', 'trade_code', ' close'])
+        df.rename({'ts_code': 'fund_id', 'trade_date': 'end_date', 'close': 'adj_nav'}, axis=1, inplace=True)
+        df['end_date'] = pd.to_datetime(df['end_date'])
+        df.set_index('end_date', drop=True, inplace=True)
+        df.sort_index(inplace=True, ascending=True)
+        df = rename_col(df, index_id)
     return df
 
 
@@ -158,9 +193,14 @@ def get_tamp_fund():
     Returns:
 
     """
-    sql = "SELECT id FROM tamp_fund_info WHERE id LIKE 'HF%'"
-    df = pd.read_sql(sql, con)
-    df.rename({'id': 'fund_id'}, axis=1, inplace=True)
+    with TAMP_SQL(tamp_fund_engine) as tamp_product:
+        tamp_product_session = tamp_product.session
+        sql = "SELECT id FROM tamp_fund_info WHERE id LIKE 'HF%'"
+        cur = tamp_product_session.execute(sql)
+        data = cur.fetchall()
+        # df = pd.read_sql(sql, con)
+        df = pd.DataFrame(list(data), columns=['id'])
+        df.rename({'id': 'fund_id'}, axis=1, inplace=True)
     return df
 
 
@@ -203,9 +243,14 @@ def get_radar_data(fund):
 
 
 def get_fund_name(fund):
-    sql = "SELECT fund_short_name FROM fund_info WHERE id='{}'".format(fund)
-    df = pd.read_sql(sql, con)
-    return df
+    with TAMP_SQL(tamp_fund_engine) as tamp_product:
+        tamp_product_session = tamp_product.session
+        sql = "SELECT fund_short_name FROM fund_info WHERE id='{}'".format(fund)
+        # df = pd.read_sql(sql, con)
+        cur = tamp_product_session.execute(sql)
+        data = cur.fetchall()
+        df = pd.DataFrame(list(data), columns=['fund_short_name'])
+        return df
 
 
 # 获取排名信息
@@ -336,12 +381,12 @@ class PortfolioDiagnose(object):
             # 建议替换得分为60或与其他基金相关度大于0.8的基金
             if z_score < 60:
                 self.abandon_fund_score.append(fund)
-                prod = prod.drop(fund, axis=1)
+                continue
 
-            if np.any(self.old_correlation[fund] > 0.8):
+            elif np.any(self.old_correlation[fund] > 0.8):
                 self.abandon_fund_corr.append(fund)
-                prod = prod.drop(fund, axis=1)
 
+        prod = prod.drop(self.abandon_fund_score + self.abandon_fund_corr, axis=1)
         self.old_correlation = self.old_correlation.fillna(1).round(2)
         self.old_correlation.columns = self.old_correlation.columns.map(lambda x: get_fund_name(x).values[0][0])
         self.old_correlation.index = self.old_correlation.index.map(lambda x: get_fund_name(x).values[0][0])
@@ -390,7 +435,7 @@ class PortfolioDiagnose(object):
                 proposal_nav = rename_col(proposal_nav, proposal)
 
                 # 按最大周期进行重采样,计算新建组合的相关性
-                prod = pd.merge(prod, proposal_nav, how='outer', on='end_date')
+                prod = pd.merge(prod, proposal_nav, how='outer', on='end_date').astype(float)
                 prod.sort_index(inplace=True)
                 prod.ffill(inplace=True)
                 prod = resample(prod, get_trade_cal(), min(self.freq_list))
@@ -442,7 +487,6 @@ class PortfolioDiagnose(object):
             #                           (set(self.proposal_fund) | set(self.replace_pair.values())))
             # propose_portfolio.drop()
 
-
         propose_risk_mapper = dict()
         for fund in propose_portfolio.columns:
             propose_risk_mapper[fund] = str(get_risk_level(search_rank(fund_rank, fund, metric='substrategy')))
@@ -526,13 +570,14 @@ class PortfolioDiagnose(object):
             num = len(fund_rank_re)
             fund_id_rank_list = list(fund_rank_re["fund_id"])
             for f_id in fund_id_rank_list:
-                name = data_adaptor.user_customer_order_df[data_adaptor.user_customer_order_df["fund_id"] == f_id]["fund_name"].values[0]
+                name = data_adaptor.user_customer_order_df[data_adaptor.user_customer_order_df["fund_id"] == f_id][
+                    "fund_name"].values[0]
                 return_rank_evaluate = return_rank_evaluate + name + "、"
-            return_rank_evaluate = return_rank_evaluate[:-1] +"等" + str(num) + "只产品稳健,对组合的收益率贡献明显,"
+            return_rank_evaluate = return_rank_evaluate[:-1] + "等" + str(num) + "只产品稳健,对组合的收益率贡献明显,"
 
         # 正收益基金数量
         group_hold_data = pd.DataFrame(group_result[group_name]["group_hoding_info"])
-        profit_positive_num = group_hold_data[group_hold_data["profit"]>0]["profit"].count()
+        profit_positive_num = group_hold_data[group_hold_data["profit"] > 0]["profit"].count()
         if profit_positive_num > 0:
             profit_positive_evaluate = str(profit_positive_num) + "只基金取的正收益,"
         else:
@@ -548,28 +593,29 @@ class PortfolioDiagnose(object):
         else:
             no_data_fund_evaluate = "ï¼›"
 
-        group_order_df = data_adaptor.user_customer_order_df[data_adaptor.user_customer_order_df["folio_name"] == group_name]
+        group_order_df = data_adaptor.user_customer_order_df[
+            data_adaptor.user_customer_order_df["folio_name"] == group_name]
         strategy_list = group_order_df["substrategy"]
         uniqe_strategy = list(strategy_list.unique())
         uniqe_strategy_name = [dict_substrategy[int(x)] + "、" for x in uniqe_strategy]
         # 覆盖的基金名称
         strategy_name_evaluate = "".join(uniqe_strategy_name)[:-1]
 
-
-        if len(uniqe_strategy)/float(len(strategy_list)) > 0.6:
+        if len(uniqe_strategy) / float(len(strategy_list)) > 0.6:
             strategy_distribution_evaluate = "策略上有一定分散"
         else:
             strategy_distribution_evaluate = "策略分散程度不高"
         # 相关性
         if len(self.abandon_fund_corr) > 0:
-            fund_corr_name = [str(group_order_df[group_order_df["fund_id"] == f_id]["fund_name"].values[0]) + "和" for f_id in self.abandon_fund_corr]
+            fund_corr_name = [str(group_order_df[group_order_df["fund_id"] == f_id]["fund_name"].values[0]) + "和" for
+                              f_id in self.abandon_fund_corr]
             fund_corr_evaluate = "".join(fund_corr_name)[:-1] + "相关性较高,建议调整组合配比;"
         else:
             fund_corr_evaluate = "ï¼›"
 
         num_fund = len(self.portfolio)
         evaluate_enum = [["优秀", "良好", "一般"],
-         ["优秀", "良好", "合格", "较差"]]
+                         ["优秀", "良好", "合格", "较差"]]
 
         z_score_evaluate = evaluate_enum[0][z_score_level]
         drawdown_evaluate = evaluate_enum[1][drawdown_level]
@@ -605,7 +651,7 @@ class PortfolioDiagnose(object):
         hold_info = group_result_data["group_hoding_info"]
 
         # 原组合总市值, 区间收益, 年化收益,	波动率,	最大回撤,	夏普比率
-        total_asset = round(pd.DataFrame(hold_info)["market_values"].sum(),2)
+        total_asset = round(pd.DataFrame(hold_info)["market_values"].sum(), 2)
         old_return = group_result_data["cumulative_return"]
         old_return_ratio_year = group_result_data["return_ratio_year"]
         old_volatility = group_result_data["volatility"]
@@ -625,9 +671,6 @@ class PortfolioDiagnose(object):
         propose_fund_df = product_df[product_df["fund_id"].isin(propose_fund_id_list)]
         propose_fund_id_list_name = []  # 基金名称,策略分级
 
-
-
-
         # hold_fund = set(self.portfolio) - set(self.abandon_fund)
         # abandon_fund = self.abandon_fund
         # proposal_fund = self.proposal_fund
@@ -733,10 +776,13 @@ class PortfolioDiagnose(object):
         else:
             evaluation = choose_bad_evaluation(data)
 
-        ret = ""
+        ret = []
+        i = 1
         for k, v in evaluation.items():
-            # print(translate_single(content[k], v))
-            ret = ret + sentence[k] % translate_single(content[k], v)
+            print(k, v)
+            single_sentence = str(i) + "." + sentence[k] % translate_single(content, k, v)
+            ret.append(single_sentence)
+            i += 1
         fund_name = get_fund_name(fund_id).values[0][0]
         return {'name': fund_name, 'data': ret}
 
@@ -774,11 +820,11 @@ class PortfolioDiagnose(object):
 portfolio = ['HF00002JJ2', 'HF00005DBQ', 'HF0000681Q', 'HF00006693', 'HF00006AZF', 'HF00006BGS']
 portfolio_diagnose = PortfolioDiagnose(client_type=1, portfolio=portfolio, invest_amount=10000000)
 portfolio_diagnose.optimize()
-# if __name__ == '__main__':
-#     print(portfolio_diagnose.single_fund_radar())
-    # print(portfolio_diagnose.propose_fund_radar())
-    # print(portfolio_diagnose.old_portfolio_evaluation())
-    # print('旧组合相关性:', portfolio_diagnose.old_correlation)
-    # print('新组合相关性:', portfolio_diagnose.new_correlation)
-    # print('旧组合个基评价:', portfolio_diagnose.old_portfolio_evaluation())
-    # print('新组合个基评价:', portfolio_diagnose.propose_fund_evaluation())
+if __name__ == '__main__':
+    print(portfolio_diagnose.single_fund_radar())
+    print(portfolio_diagnose.propose_fund_radar())
+    print(portfolio_diagnose.old_portfolio_evaluation())
+    print('旧组合相关性:', portfolio_diagnose.old_correlation)
+    print('新组合相关性:', portfolio_diagnose.new_correlation)
+    print('旧组合个基评价:', portfolio_diagnose.old_portfolio_evaluation())
+    print('新组合个基评价:', portfolio_diagnose.propose_fund_evaluation())
diff --git a/app/utils/fund_rank.py b/app/utils/fund_rank.py
index 9c25a00..df51c49 100644
--- a/app/utils/fund_rank.py
+++ b/app/utils/fund_rank.py
@@ -1,17 +1,18 @@
-# import pymysql
+
 from sqlalchemy import create_engine
 
-db = create_engine(
-    'mysql+pymysql://tamp_fund:@imeng408@tamper.mysql.polardb.rds.aliyuncs.com:3306/tamp_fund?charset=utf8mb4',
-    pool_size=50,
-    pool_recycle=3600,
-    pool_pre_ping=True)
-con = db.connect()
 
-import logging
+# db = create_engine(
+#     'mysql+pymysql://tamp_fund:@imeng408@tamper.mysql.polardb.rds.aliyuncs.com:3306/tamp_fund?charset=utf8mb4',
+#     pool_size=50,
+#     pool_recycle=3600,
+#     pool_pre_ping=True)
+# con = db.connect()
 
+import logging
 logging.basicConfig(level=logging.INFO)
 
+from app.api.engine import tamp_fund_engine, TAMP_SQL
 from app.utils.week_evaluation import *
 
 
@@ -35,32 +36,39 @@ def get_nav(fund, start_date, rollback=False, invest_type='public'):
     Returns:df[DataFrame]: 索引为净值公布日, 列为复权净值的净值表; 查询失败则返回None
 
     """
-    if invest_type == 'public':
-        sql = "SELECT ts_code, end_date, adj_nav FROM public_fund_nav " \
-              "WHERE ts_code='{}'".format(fund)
-        df = pd.read_sql(sql, con).dropna(how='any')
-        df.rename({'ts_code': 'fund_id'}, axis=1, inplace=True)
-    else:
-        sql = "SELECT fund_id, price_date, cumulative_nav FROM fund_nav " \
-              "WHERE fund_id='{}'".format(fund)
-        df = pd.read_sql(sql, con).dropna(how='any')
-        df.rename({'price_date': 'end_date', 'cumulative_nav': 'adj_nav'}, axis=1, inplace=True)
-
-    if df['adj_nav'].count() == 0:
-        logging.log(logging.ERROR, "CAN NOT FIND {}".format(fund))
-        return None
-
-    df['end_date'] = pd.to_datetime(df['end_date'])
-
-    if rollback and df['end_date'].min() < start_date < df['end_date'].max():
-        while start_date not in list(df['end_date']):
-            start_date -= datetime.timedelta(days=1)
-
-    df = df[df['end_date'] >= start_date]
-    df.drop_duplicates(subset='end_date', inplace=True, keep='first')
-    df.set_index('end_date', inplace=True)
-    df.sort_index(inplace=True, ascending=True)
-    return df
+    with TAMP_SQL(tamp_fund_engine) as tamp_product:
+        tamp_product_session = tamp_product.session
+        if invest_type == 'public':
+            sql = "SELECT ts_code, end_date, adj_nav FROM public_fund_nav " \
+                  "WHERE ts_code='{}'".format(fund)
+            cur = tamp_product_session.execute(sql)
+            data = cur.fetchall()
+            df = pd.DataFrame(list(data), columns=['ts_code', 'end_date', 'adj_nav']).dropna(how='any')
+            df.rename({'ts_code': 'fund_id'}, axis=1, inplace=True)
+        else:
+            sql = "SELECT fund_id, price_date, cumulative_nav FROM fund_nav " \
+                  "WHERE fund_id='{}'".format(fund)
+            # df = pd.read_sql(sql, con).dropna(how='any')
+            cur = tamp_product_session.execute(sql)
+            data = cur.fetchall()
+            df = pd.DataFrame(data, columns=['fund_id', 'price_date', 'cumulative_nav']).dropna(how='any')
+            df.rename({'price_date': 'end_date', 'cumulative_nav': 'adj_nav'}, axis=1, inplace=True)
+
+        if df['adj_nav'].count() == 0:
+            logging.log(logging.ERROR, "CAN NOT FIND {}".format(fund))
+            return None
+
+        df['end_date'] = pd.to_datetime(df['end_date'])
+
+        if rollback and df['end_date'].min() < start_date < df['end_date'].max():
+            while start_date not in list(df['end_date']):
+                start_date -= datetime.timedelta(days=1)
+
+        df = df[df['end_date'] >= start_date]
+        df.drop_duplicates(subset='end_date', inplace=True, keep='first')
+        df.set_index('end_date', inplace=True)
+        df.sort_index(inplace=True, ascending=True)
+        return df
 
 
 def get_frequency(df):
@@ -97,11 +105,16 @@ def get_trade_cal():
     Returns:df[DataFrame]: 索引为交易日, 列为交易日的上交所交易日历表
 
     """
-    sql = 'SELECT cal_date FROM stock_trade_cal WHERE is_open=1'
-    df = pd.read_sql(sql, con)
-    df['end_date'] = pd.to_datetime(df['cal_date'])
-    df.set_index('end_date', drop=False, inplace=True)
-    return df
+    with TAMP_SQL(tamp_fund_engine) as tamp_product:
+        tamp_product_session = tamp_product.session
+        sql = 'SELECT cal_date FROM stock_trade_cal WHERE is_open=1'
+        cur = tamp_product_session.execute(sql)
+        data = cur.fetchall()
+        df = pd.DataFrame(list(data), columns=['cal_date']).dropna(how='all')
+        # df = pd.read_sql(sql, con)
+        df['end_date'] = pd.to_datetime(df['cal_date'])
+        df.set_index('end_date', drop=False, inplace=True)
+        return df
 
 
 def get_manager(invest_type):
@@ -113,13 +126,21 @@ def get_manager(invest_type):
     Returns:
 
     """
-    if invest_type == 'public':
-        sql = 'SELECT ts_code, name FROM public_fund_manager WHERE end_date IS NULL'
-        df = pd.read_sql(sql, con)
-    else:
-        sql = 'SELECT fund_id, fund_manager_id FROM fund_manager_mapping'
-        df = pd.read_sql(sql, con)
-    return df
+    with TAMP_SQL(tamp_fund_engine) as tamp_product:
+        tamp_product_session = tamp_product.session
+        if invest_type == 'public':
+            sql = 'SELECT ts_code, name FROM public_fund_manager WHERE end_date IS NULL'
+            # df = pd.read_sql(sql, con)
+            cur = tamp_product_session.execute(sql)
+            data = cur.fetchall()
+            df = pd.DataFrame(list(data), columns=['ts_code', 'name'])
+        else:
+            sql = 'SELECT fund_id, fund_manager_id FROM fund_manager_mapping'
+            # df = pd.read_sql(sql, con)
+            cur = tamp_product_session.execute(sql)
+            data = cur.fetchall()
+            df = pd.DataFrame(list(data), columns=['fund_id', 'fund_manager_id'])
+        return df
 
 
 def get_fund_info(end_date, invest_type):
@@ -132,23 +153,33 @@ def get_fund_info(end_date, invest_type):
     Returns:
         [type]: [description]
     """
-    if invest_type == 'public':
-        sql = "SELECT ts_code, fund_type, management FROM public_fund_basic " \
-              "WHERE delist_date IS NULL AND (due_date IS NULL OR due_date>'{}')".format(end_date.strftime('%Y%m%d'))
-        df = pd.read_sql(sql, con).dropna(how='all')
-        manager_info = get_manager(invest_type)
-
-        df.rename({'ts_code': 'fund_id'}, axis=1, inplace=True)
-        df = pd.merge(df, manager_info, how="left", on='fund_id')
-    else:
-        sql = "SELECT id, substrategy FROM fund_info WHERE delete_tag=0 " \
-              "AND substrategy!=-1"
-        df = pd.read_sql(sql, con).dropna(how='all')
+    with TAMP_SQL(tamp_fund_engine) as tamp_product:
+        tamp_product_session = tamp_product.session
+        if invest_type == 'public':
+            sql = "SELECT ts_code, fund_type, management FROM public_fund_basic " \
+                  "WHERE delist_date IS NULL AND (due_date IS NULL OR due_date>'{}')".format(end_date.strftime('%Y%m%d'))
+            # df = pd.read_sql(sql, con).dropna(how='all')
+            cur = tamp_product_session.execute(sql)
+            data = cur.fetchall()
+
+            df = pd.DataFrame(list(data), columns=['ts_code', 'fund_type', 'management'])
+            manager_info = get_manager(invest_type)
+
+            df.rename({'ts_code': 'fund_id'}, axis=1, inplace=True)
+            df = pd.merge(df, manager_info, how="left", on='fund_id')
+        else:
 
-        df.rename({'id': 'fund_id'}, axis=1, inplace=True)
-        manager_info = get_manager(invest_type)
-        df = pd.merge(df, manager_info, how="inner", on='fund_id')
-    return df
+            sql = "SELECT id, substrategy FROM fund_info WHERE delete_tag=0 " \
+                  "AND substrategy!=-1"
+            cur = tamp_product_session.execute(sql)
+            data = cur.fetchall()
+            df = pd.DataFrame(list(data), columns=['id', 'substrategy'])
+            # df = pd.read_sql(sql, con).dropna(how='all')
+
+            df.rename({'id': 'fund_id'}, axis=1, inplace=True)
+            manager_info = get_manager(invest_type)
+            df = pd.merge(df, manager_info, how="inner", on='fund_id')
+        return df
 
 
 def resample(df, trading_cal, freq, simple_flag=True):
@@ -301,4 +332,4 @@ if __name__ == '__main__':
     # fund_rank.to_csv("fund_rank.csv", encoding='gbk')
     # df = pd.read_csv('fund_rank.csv')
     # df.to_sql("fund_rank", con, if_exists='replace')
-    con.close()
+    # con.close()
-- 
2.18.1