Skip to content

Commit 0dbd505

Browse files
authored
codegen: middleware snapshot tests (#502)
1 parent 1c1f3f0 commit 0dbd505

File tree

4 files changed

+197
-1
lines changed

4 files changed

+197
-1
lines changed

codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoStdlibTypes.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ private GoStdlibTypes() { }
2626

2727
public static final class Context {
2828
public static final Symbol Context = SmithyGoDependency.CONTEXT.valueSymbol("Context");
29+
public static final Symbol Background = SmithyGoDependency.CONTEXT.valueSymbol("Background");
2930
}
3031

3132
public static final class Fmt {
@@ -42,4 +43,8 @@ public static final class Http {
4243
public static final class Path {
4344
public static final Symbol Join = SmithyGoDependency.PATH.valueSymbol("Join");
4445
}
46+
47+
public static final class Testing {
48+
public static final Symbol T = SmithyGoDependency.TESTING.pointableSymbol("T");
49+
}
4550
}

codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoWriter.java

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ public final class GoWriter extends AbstractCodeWriter<GoWriter> {
6565
private final ImportDeclarations imports = new ImportDeclarations();
6666
private final List<SymbolDependency> dependencies = new ArrayList<>();
6767
private final boolean innerWriter;
68+
private final List<String> buildTags = new ArrayList<>();
6869

6970
private int docWrapLength = DEFAULT_DOC_WRAP_LENGTH;
7071
private AbstractCodeWriter<GoWriter> packageDocs;
@@ -93,6 +94,7 @@ private void init() {
9394
putFormatter('T', new GoSymbolFormatter());
9495
putFormatter('P', new PointableGoSymbolFormatter());
9596
putFormatter('W', new GoWritableInjector());
97+
putFormatter('D', new GoDependencyFormatter());
9698

9799
if (!innerWriter) {
98100
packageDocs = new GoWriter(this.fullPackageName, true);
@@ -881,6 +883,11 @@ public void write(Writable w) {
881883
write("$W", w);
882884
}
883885

886+
public GoWriter addBuildTag(String tag) {
887+
buildTags.add(tag);
888+
return this;
889+
}
890+
884891
@Override
885892
public String toString() {
886893
String contents = super.toString();
@@ -889,6 +896,9 @@ public String toString() {
889896
return contents;
890897
}
891898

899+
var tags = buildTags.isEmpty()
900+
? ""
901+
: "//go:build " + String.join(",", buildTags) + "\n";
892902

893903
String[] packageParts = fullPackageName.split("/");
894904
String header = String.format("// Code generated by smithy-go-codegen DO NOT EDIT.%n%n");
@@ -919,7 +929,7 @@ public String toString() {
919929
return header + strippedImportString + "\n" + strippedContents;
920930
}
921931

922-
return header + packageDocs + packageStatement + importString + contents;
932+
return header + packageDocs + tags + packageStatement + importString + contents;
923933
}
924934

925935
/**
@@ -1013,6 +1023,22 @@ public String apply(Object type, String indent) {
10131023
}
10141024
}
10151025

1026+
/**
1027+
* Implements Go symbol formatting for the {@code $D} formatter.
1028+
*/
1029+
private class GoDependencyFormatter implements BiFunction<Object, String, String> {
1030+
@Override
1031+
public String apply(Object type, String indent) {
1032+
if (type instanceof GoDependency) {
1033+
addUseImports((GoDependency) type);
1034+
} else {
1035+
throw new CodegenException(
1036+
"Invalid type provided to $D. Expected a GoDependency, but found `" + type + "`");
1037+
}
1038+
return "";
1039+
}
1040+
}
1041+
10161042
public interface Writable extends Consumer<GoWriter> {
10171043
}
10181044

codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ public final class SmithyGoDependency {
3737
public static final GoDependency JSON = stdlib("encoding/json");
3838
public static final GoDependency IO = stdlib("io");
3939
public static final GoDependency IOUTIL = stdlib("io/ioutil");
40+
public static final GoDependency FS = stdlib("io/fs");
4041
public static final GoDependency CRYPTORAND = stdlib("crypto/rand", "cryptorand");
4142
public static final GoDependency TESTING = stdlib("testing");
4243
public static final GoDependency ERRORS = stdlib("errors");
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
/*
2+
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://aws.amazon.com/apache2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
package software.amazon.smithy.go.codegen.integration;
17+
18+
import static software.amazon.smithy.go.codegen.GoWriter.goTemplate;
19+
20+
import software.amazon.smithy.codegen.core.SymbolProvider;
21+
import software.amazon.smithy.go.codegen.GoDelegator;
22+
import software.amazon.smithy.go.codegen.GoSettings;
23+
import software.amazon.smithy.go.codegen.GoStdlibTypes;
24+
import software.amazon.smithy.go.codegen.GoWriter;
25+
import software.amazon.smithy.go.codegen.SmithyGoDependency;
26+
import software.amazon.smithy.go.codegen.SmithyGoTypes;
27+
import software.amazon.smithy.model.Model;
28+
import software.amazon.smithy.model.knowledge.TopDownIndex;
29+
import software.amazon.smithy.model.shapes.OperationShape;
30+
import software.amazon.smithy.model.shapes.ServiceShape;
31+
import software.amazon.smithy.utils.MapUtils;
32+
33+
public class MiddlewareStackSnapshotTests implements GoIntegration {
34+
@Override
35+
public void writeAdditionalFiles(
36+
GoSettings settings, Model model, SymbolProvider symbolProvider, GoDelegator goDelegator
37+
) {
38+
goDelegator.useFileWriter("snapshot_test.go", settings.getModuleName(), writer -> {
39+
writer.addBuildTag("snapshot");
40+
writer.write(commonTestSource());
41+
writer.write(snapshotTests(model, settings.getService(model), symbolProvider));
42+
writer.write(snapshotUpdaters(model, settings.getService(model), symbolProvider));
43+
});
44+
}
45+
46+
private GoWriter.Writable commonTestSource() {
47+
return goTemplate("""
48+
$os:D $fs:D $io:D $errors:D $fmt:D $middleware:D
49+
50+
const ssprefix = "snapshot"
51+
52+
type snapshotOK struct{}
53+
54+
func (snapshotOK) Error() string { return "error: success" }
55+
56+
func createp(path string) (*os.File, error) {
57+
if err := os.Mkdir(ssprefix, 0700); err != nil && !errors.Is(err, fs.ErrExist) {
58+
return nil, err
59+
}
60+
return os.Create(path)
61+
}
62+
63+
func sspath(op string) string {
64+
return fmt.Sprintf("%s/api_op_%s.go.snap", ssprefix, op)
65+
}
66+
67+
func updateSnapshot(stack *middleware.Stack, operation string) error {
68+
f, err := createp(sspath(operation))
69+
if err != nil {
70+
return err
71+
}
72+
defer f.Close()
73+
if _, err := f.Write([]byte(stack.String())); err != nil {
74+
return err
75+
}
76+
return snapshotOK{}
77+
}
78+
79+
func testSnapshot(stack *middleware.Stack, operation string) error {
80+
f, err := os.Open(sspath(operation))
81+
if errors.Is(err, fs.ErrNotExist) {
82+
return snapshotOK{}
83+
}
84+
if err != nil {
85+
return err
86+
}
87+
defer f.Close()
88+
expected, err := io.ReadAll(f)
89+
if err != nil {
90+
return err
91+
}
92+
if actual := stack.String(); actual != string(expected) {
93+
return fmt.Errorf("%s != %s", expected, actual)
94+
}
95+
return snapshotOK{}
96+
}
97+
""",
98+
MapUtils.of(
99+
"errors", SmithyGoDependency.ERRORS, "fmt", SmithyGoDependency.FMT,
100+
"fs", SmithyGoDependency.FS, "io", SmithyGoDependency.IO,
101+
"middleware", SmithyGoDependency.SMITHY_MIDDLEWARE, "os", SmithyGoDependency.OS
102+
));
103+
}
104+
105+
private GoWriter.Writable snapshotUpdaters(Model model, ServiceShape service, SymbolProvider symbolProvider) {
106+
return GoWriter.ChainWritable.of(
107+
TopDownIndex.of(model).getContainedOperations(service).stream()
108+
.map(it -> testUpdateSnapshot(it, symbolProvider))
109+
.toList()
110+
).compose();
111+
}
112+
113+
private GoWriter.Writable snapshotTests(Model model, ServiceShape service, SymbolProvider symbolProvider) {
114+
return GoWriter.ChainWritable.of(
115+
TopDownIndex.of(model).getContainedOperations(service).stream()
116+
.map(it -> testCheckSnapshot(it, symbolProvider))
117+
.toList()
118+
).compose();
119+
}
120+
121+
private GoWriter.Writable testUpdateSnapshot(OperationShape operation, SymbolProvider symbolProvider) {
122+
return goTemplate("""
123+
func TestUpdateSnapshot_$operation:L(t $testingT:P) {
124+
svc := New(Options{})
125+
_, err := svc.$operation:L($contextBackground:T(), nil, func(o *Options) {
126+
o.APIOptions = append(o.APIOptions, func(stack $middlewareStack:P) error {
127+
return updateSnapshot(stack, $operation:S)
128+
})
129+
})
130+
if _, ok := err.(snapshotOK); !ok && err != nil {
131+
t.Fatal(err)
132+
}
133+
}
134+
""",
135+
MapUtils.of(
136+
"testingT", GoStdlibTypes.Testing.T,
137+
"contextBackground", GoStdlibTypes.Context.Background,
138+
"middlewareStack", SmithyGoTypes.Middleware.Stack,
139+
"operation", symbolProvider.toSymbol(operation).getName()
140+
));
141+
}
142+
143+
private GoWriter.Writable testCheckSnapshot(OperationShape operation, SymbolProvider symbolProvider) {
144+
return goTemplate("""
145+
func TestCheckSnapshot_$operation:L(t $testingT:P) {
146+
svc := New(Options{})
147+
_, err := svc.$operation:L($contextBackground:T(), nil, func(o *Options) {
148+
o.APIOptions = append(o.APIOptions, func(stack $middlewareStack:P) error {
149+
return testSnapshot(stack, $operation:S)
150+
})
151+
})
152+
if _, ok := err.(snapshotOK); !ok && err != nil {
153+
t.Fatal(err)
154+
}
155+
}
156+
""",
157+
MapUtils.of(
158+
"testingT", GoStdlibTypes.Testing.T,
159+
"contextBackground", GoStdlibTypes.Context.Background,
160+
"middlewareStack", SmithyGoTypes.Middleware.Stack,
161+
"operation", symbolProvider.toSymbol(operation).getName()
162+
));
163+
}
164+
}

0 commit comments

Comments
 (0)