这是之前参加数学建模竞赛时遇到的一个问题。当时比赛尚在进行，不方便马上发布，现在整理出来。

本次数学建模涉及逻辑回归（logistic regression）模型的应用，这里简单记录一下用 TensorFlow 的实现。

1、环境

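我比较偷懒，直接使用了 Google Cloud Platform，免去了配置 TensorFlow 环境的麻烦。注意：下面的代码使用的是 TensorFlow 1.x 的 API（tf.placeholder、tf.Session、tf.train 等）；在 TensorFlow 2.x 中需要改用 tf.compat.v1 并先调用 tf.compat.v1.disable_eager_execution()。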

2、代码

import pandas as pd                # 用于读取数据文件
import tensorflow as tf
import matplotlib.pyplot as plt    # 用于画图
import numpy as np

df = pd.read_csv("input.csv", header=None)
train_data = df.to_numpy()

print(train_data)

# Layout of input.csv: every column but the last is a feature, the last
# column is the 0/1 class label (kept 2-D so it matches the model output).
train_X, train_y = train_data[:, :-1], train_data[:, -1:]
sample_num, feature_num = train_X.shape
print("Size of train_X: {}x{}".format(sample_num, feature_num))
print("Size of train_y: {}x{}".format(*train_y.shape))

# Inputs: feature matrix (sample_num x feature_num) and the 0/1 label column
# (sample_num x 1). Declaring the shapes makes a mis-shaped feed fail loudly.
X = tf.placeholder(tf.float32, shape=[None, feature_num])
y = tf.placeholder(tf.float32, shape=[None, 1])

# Model parameters: one weight per feature plus a scalar bias.
W = tf.Variable(tf.zeros([feature_num, 1]))
b = tf.Variable([-.9])

# Linear score (logit) and predicted probability P(y = 1 | x).
# W is created with shape [feature_num, 1], so no reshape is required.
db = tf.matmul(X, W) + b
hyp = tf.sigmoid(db)

# Mean binary cross-entropy, i.e. -mean(y*log(hyp) + (1-y)*log(1-hyp)).
# Computing it from the logits avoids log(0) = -inf / NaN once the sigmoid
# saturates to exactly 0 or 1 in float32; value and gradient are otherwise
# the same as the explicit formula.
cost = tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=db)
loss = tf.reduce_mean(cost)

optimizer = tf.train.GradientDescentOptimizer(0.001)
train = optimizer.minimize(loss)

init = tf.global_variables_initializer()
# The session is intentionally left open so W and b can be inspected afterwards.
sess = tf.Session()
sess.run(init)

# Build the feed once and reuse it for every optimisation step.
feed_dict = {X: train_X, y: train_y}

max_steps = 1000000
log_every = 10000
prev_params = None
for step in range(max_steps):
    sess.run(train, feed_dict)
    if step % log_every == 0:
        # Fetch both parameters in a single graph execution.
        W_val, b_val = sess.run([W, b])
        print(step, W_val.flatten(), b_val.flatten())
        # Full-batch gradient descent: once W and b are bit-identical to the
        # previous checkpoint they have stopped changing at float32
        # precision, so the remaining steps would only repeat the same output.
        params = np.concatenate([W_val.ravel(), b_val.ravel()])
        if prev_params is not None and np.array_equal(params, prev_params):
            print("Converged at step {}; stopping early.".format(step))
            break
        prev_params = params

# Plotting: training samples coloured by class, plus the learned decision plane.

# Parameters copied from the converged training log (W and b). Separate names
# avoid clobbering the TF variable `b` and the placeholder `y` defined above.
w_fit = np.array([0.7672361, -0.276697, -0.19542742])
b_fit = 0.09650069

from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  registers "3d" projection on old Matplotlib

x1 = train_data[:, 0]
x2 = train_data[:, 1]
x3 = train_data[:, 2]
labels = train_data[:, -1]

fig = plt.figure()
# add_subplot(projection="3d") replaces Axes3D(fig) / fig.gca(projection=...),
# which no longer attach to / are accepted by current Matplotlib.
ax = fig.add_subplot(projection="3d")

# One vectorised scatter call per class instead of one call per sample.
is_neg = labels == 0
ax.scatter(x1[is_neg], x2[is_neg], x3[is_neg], c='r')
ax.scatter(x1[~is_neg], x2[~is_neg], x3[~is_neg], c='g')

ax.set_zlabel('Z')  # axis labels
ax.set_ylabel('Y')
ax.set_xlabel('X')

# Decision boundary: sigmoid(w.x + b) = 0.5  <=>  w0*x + w1*y + w2*z + b = 0
#                                         <=>  z = -(w0*x + w1*y + b) / w2
# The grid spans the observed feature range so the plane covers the points.
gx = np.linspace(x1.min(), x1.max(), 10)
gy = np.linspace(x2.min(), x2.max(), 10)
X_grid, Y_grid = np.meshgrid(gx, gy)
Z_grid = -(w_fit[0] * X_grid + w_fit[1] * Y_grid + b_fit) / w_fit[2]

# Draw the plane on the same axes as the samples, semi-transparent.
surf = ax.plot_surface(X_grid, Y_grid, Z_grid, alpha=0.5)
plt.show()

结果（训练时每 10000 步打印一次：步数、W、b；从第 230000 步起参数不再变化，即已收敛）：

# Result
# 0 [ 0.7707945  -0.27244025 -0.19312732] [0.08507074]
# 10000 [ 0.77019846 -0.2730363  -0.19343074] [0.08675246]
# 20000 [ 0.76964736 -0.27363235 -0.19366762] [0.08829048]
# 30000 [ 0.7692501  -0.27412957 -0.19393419] [0.08960892]
# 40000 [ 0.7689034  -0.27445695 -0.19417214] [0.09070877]
# 50000 [ 0.7686214  -0.27475497 -0.19436422] [0.09161869]
# 60000 [ 0.7683936  -0.275053   -0.19451982] [0.09238537]
# 70000 [ 0.7682283  -0.27535102 -0.19466883] [0.09303201]
# 80000 [ 0.76810503 -0.27564904 -0.19481784] [0.09360377]
# 90000 [ 0.76798075 -0.27594706 -0.19491854] [0.09411547]
# 100000 [ 0.7678499  -0.2761182  -0.19500598] [0.09455331]
# 110000 [ 0.76773375 -0.27622774 -0.19508578] [0.09492187]
# 120000 [ 0.7676398  -0.27631634 -0.19515029] [0.0952199]
# 130000 [ 0.7675626  -0.27638906 -0.19520319] [0.09546462]
# 140000 [ 0.7674922  -0.27645552 -0.19525155] [0.09568814]
# 150000 [ 0.76744366 -0.2765014  -0.19528496] [0.09584232]
# 160000 [ 0.7673967  -0.27654564 -0.19531724] [0.09599134]
# 170000 [ 0.76735574 -0.27658433 -0.19534537] [0.09612132]
# 180000 [ 0.76733226 -0.2766065  -0.19536147] [0.09619582]
# 190000 [ 0.7673087  -0.27662855 -0.1953776 ] [0.09627033]
# 200000 [ 0.7672852  -0.27665073 -0.1953937 ] [0.09634484]
# 210000 [ 0.7672618  -0.2766729  -0.19540988] [0.09641934]
# 220000 [ 0.7672383  -0.2766951  -0.19542597] [0.09649385]
# 230000 [ 0.7672361  -0.276697   -0.19542742] [0.09650069]
# 240000 [ 0.7672361  -0.276697   -0.19542742] [0.09650069]
# 250000 [ 0.7672361  -0.276697   -0.19542742] [0.09650069]
# 260000 [ 0.7672361  -0.276697   -0.19542742] [0.09650069]
# 270000 [ 0.7672361  -0.276697   -0.19542742] [0.09650069]
# 280000 [ 0.7672361  -0.276697   -0.19542742] [0.09650069]
# 290000 [ 0.7672361  -0.276697   -0.19542742] [0.09650069]
# 300000 [ 0.7672361  -0.276697   -0.19542742] [0.09650069]
# 310000 [ 0.7672361  -0.276697   -0.19542742] [0.09650069]
# 320000 [ 0.7672361  -0.276697   -0.19542742] [0.09650069]
# 330000 [ 0.7672361  -0.276697   -0.19542742] [0.09650069]
# 340000 [ 0.7672361  -0.276697   -0.19542742] [0.09650069]
# 350000 [ 0.7672361  -0.276697   -0.19542742] [0.09650069]
# 360000 [ 0.7672361  -0.276697   -0.19542742] [0.09650069]
# 370000 [ 0.7672361  -0.276697   -0.19542742] [0.09650069]

参考资料:
https://segmentfault.com/a/1190000009954640
http://bbs.bugcode.cn/t/20913