Python 训练集、测试集以及验证集切分方法:sklearn及手动切分

news/2024/7/24 9:08:00 标签: python, sklearn, 机器学习, pytorch, 深度学习, 算法

目录

方法一

方法二


需求目的:针对模型训练输入,按照6:2:2的比例进行训练集、测试集和验证集的划分。当前数据量约10万条。如果针对的是记录条数达上百万的数据集,可按照98:1:1的比例进行切分。

方法一:切分训练集和测试集,采用机器学习sklearn中的train_test_split()函数
方法二:切分训练集、测试集以及验证集,针对dataframe手动切分

方法一

采用Sklearn包中的sklearn.model_selection.train_test_split()函数,该函数功能是将原始数据按照比例切分为训练集和测试集。

python">函数形式:
sklearn.model_selection.train_test_split(*arrays, test_size=None, 
train_size=None, random_state=None, shuffle=True, stratify=None)

参数解读:
*arrays:等长的列表、数组或者dataframe等
test_size: 0和1之间,默认0.25
train_size: 0和1之间,默认1
random_state: 传递一个int值,以便在多个函数调用之间产生可复制的输出
shuffle: 拆分前是否进行洗牌
strafity: 是否对数据进行分层

返回结果:
输入序列的train test分割序列

例子

python">>>> import numpy as np
>>> from sklearn.model_selection import train_test_split
>>> X, y = np.arange(10).reshape((5, 2)), range(5)
>>> X
array([[0, 1],
       [2, 3],
       [4, 5],
       [6, 7],
       [8, 9]])
>>> list(y)
[0, 1, 2, 3, 4]

>>> X_train, X_test, y_train, y_test = train_test_split(
...     X, y, test_size=0.33, random_state=42)
...
>>> X_train
array([[4, 5],
       [0, 1],
       [6, 7]])
>>> y_train
[2, 0, 3]
>>> X_test
array([[2, 3],
       [8, 9]])
>>> y_test
[1, 4]

方法二

手动切分,代码如下。输入采用Python的DataFrame,同样输出三个文件。如果需要每次都输入同样的切分数据,可采用random.seed()定义随机数种子。

python">def split_train_test_valid():
    # read file
    input_path = "E:\\Data\\"
    file = "flow.csv"
    df_flow = pd.read_csv(input_path + file, header=None, encoding='gbk')

    # define the ratios 6:2:2
    train_len = int(len(df_flow) * 0.6)
    test_len = int(len(df_flow) * 0.2)

    # split the dataframe
    idx = list(df_flow.index)
    random.shuffle(idx)  # 将index列表打乱
    df_train = df_flow.loc[idx[:train_len]]
    df_test = df_flow.loc[idx[train_len:train_len+test_len]]
    df_valid = df_flow.loc[idx[train_len+test_len:]]  # 剩下的就是valid

    # output
    df_train.to_csv(input_path+'train.txt', header=False, index=False, sep='\t')
    df_test.to_csv(input_path+'test.txt', header=False, index=False, sep='\t')
    df_valid.to_csv(input_path+'valid.txt', header=False, index=False, sep='\t')

参考资料:https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html 


 


http://www.niftyadmin.cn/n/5124878.html

相关文章

如何满足TIKTOK直播企业四大网络需求,轻松实现直播无卡顿?

说到企业海外直播,大家脑海里一下就想会想到当下爆火的TIKTOK平台,而随着TIKTOK在全球范围大面积铺开推广,不同国家不同市场的活跃用户数纷纷上涨,让更多的电商企业选择在TIKTOK上进行布局获得商机。 对于已布局TIKTOK直播或者待布…

浏览器是怎么执行JS的?——消息队列与事件循环

看完渡一的课后,感觉这块内容确实非常重要,写 JS 的连 JS 的执行原理都不知道可不行。 事件循环 在写 JS 的时候,你有没有想过 JS 是按照什么顺序执行的?浏览器是怎么执行 JS 代码的?为什么有时候代码没有按照我们认为…

3.SpringSecurity基于数据库的认证与授权

文章目录 SpringSecurity基于数据库的认证与授权一、自定义用户信息UserDetails1.1 新建用户信息类UserDetails1.2 UserDetailsService 二、基于数据库的认证2.1 连接数据库2.2 获取用户信息2.2.1 获取用户实体类2.2.2 Mapper2.2.3 Service 2.3 认证2.3.1 实现UserDetails接口2…

796. 子矩阵的和(左上角前缀和)

题目: 796. 子矩阵的和 - AcWing题库 思路: 1.暴力搜索(搜索时间复杂度为O(n2),很多时候会超时) 2. 前缀和(左上角前缀和):本题特殊在不是直接求前n个数的和,而是求…

华为云服务器 Ping 延迟

参考 检查云服务器的内核参数。 检查文件/etc/sysctl.conf中配置项“net.ipv4.icmp_echo_ignore_all”的值,0表示允许Ping,1表示禁止Ping。 允许PING设置。 临时允许PING操作的命令: #echo 0 >/proc/sys/net/ipv4/icmp_echo_ignore_all永…

两个难搞的Java Error/Exception

最近维护公司的产品时,我碰到了两个头痛的Java异常。未免以后忘记了,所以写篇blog记录下这些问题和解决方法。 Entity定义 由于不能展示公司的代码,我就用书店、书、作者这些对象来说明。书店与作者之间是m:n的关系,作者与书之间…

使用多线程无法收集到子线程的日志

问题:使用多线程的时候日志收集只能收集到主线程的,子线程的日志收集不到。 解决:创建多线程的时候使用org.slf4j.MDC把主线程的信息映射到子线程 package com.wechat.util;import cn.dotfashion.soa.sheintracing.concurrent.executor.Tra…