"""Full training script for music genre classification (GTZAN)."""
import evaluate
import numpy as np
import torch
from datasets import load_dataset, Audio
from transformers import (
    AutoFeatureExtractor,
    AutoModelForAudioClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
)
# 1. Load GTZAN and carve out a stratified 80/20 train/test split
#    (stratify_by_column keeps the genre distribution identical in both splits).
dataset = load_dataset("marsyas/gtzan", "all")
dataset = dataset["train"].train_test_split(test_size=0.2, seed=42, stratify_by_column="genre")

# 2. Configuration: checkpoint name and label <-> id lookup tables.
model_name = "facebook/wav2vec2-base"
labels = dataset["train"].features["genre"].names
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}

# 3. Instantiate the feature extractor and a classification head on top of wav2vec2.
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForAudioClassification.from_pretrained(
    model_name,
    num_labels=len(labels),
    label2id=label2id,
    id2label=id2label,
)

# 4. Preprocessing: everything gets resampled to the extractor's native rate.
target_sr = feature_extractor.sampling_rate
def preprocess(examples):
    """Turn a batch of raw audio arrays into fixed-length model input features.

    Every clip is truncated/padded to exactly 30 s at ``target_sr`` so all
    stored examples share one length and the Trainer's default collator can
    stack any random batch of them.
    """
    audio = [x["array"] for x in examples["audio"]]
    # padding="max_length" (not padding=True): padding only to the longest
    # clip in each map() batch leaves examples with inconsistent lengths
    # across batches, which breaks default collation at training time.
    # return_tensors is deliberately omitted — datasets.map stores plain
    # lists/arrays anyway, so building torch tensors here is wasted work.
    return feature_extractor(
        audio,
        sampling_rate=target_sr,
        max_length=target_sr * 30,
        truncation=True,
        padding="max_length",
    )
# Resample each split to the model's expected rate, then featurize it;
# the raw audio and file-path columns are dropped once features exist.
for split_name in dataset:
    resampled = dataset[split_name].cast_column("audio", Audio(sampling_rate=target_sr))
    dataset[split_name] = resampled.map(
        preprocess,
        batched=True,
        batch_size=8,
        remove_columns=["audio", "file"],
    )
# 5. Training — accuracy is the single evaluation metric.
accuracy = evaluate.load("accuracy")

def compute_metrics(pred):
    """Return top-1 accuracy for an EvalPrediction (logits + label ids)."""
    predicted_ids = np.argmax(pred.predictions, axis=1)
    return accuracy.compute(predictions=predicted_ids, references=pred.label_ids)
# Training configuration. evaluation/save strategies must match so that
# load_best_model_at_end can restore the best epoch checkpoint.
training_args = TrainingArguments(
    output_dir="./gtzan-classifier",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=8,
    num_train_epochs=10,
    warmup_ratio=0.1,  # linear warmup over the first 10% of steps
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # fp16 is only valid on a CUDA device; hard-coding True crashes on
    # CPU-only machines, so gate it on actual GPU availability.
    fp16=torch.cuda.is_available(),
)
# Stop early once eval accuracy fails to improve for 3 consecutive epochs.
early_stop = EarlyStoppingCallback(early_stopping_patience=3)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    # Passing the feature extractor here makes Trainer save its config
    # alongside every checkpoint.
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
    callbacks=[early_stop],
)
trainer.train()
trainer.save_model("./gtzan-classifier-best")